diff --git a/.github/ISSUE_TEMPLATE/05-performance.yml b/.github/ISSUE_TEMPLATE/05-performance.yml index 829076a1bd46..da0e6c7ada7a 100644 --- a/.github/ISSUE_TEMPLATE/05-performance.yml +++ b/.github/ISSUE_TEMPLATE/05-performance.yml @@ -1,6 +1,7 @@ name: Performance description: issues related to performance title: "[Performance] " +labels: ["performance"] body: - type: markdown attributes: diff --git a/.github/labeler.yml b/.github/labeler.yml index 526d8a643e71..c14e2a213bc6 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,20 +1,25 @@ -api:javascript: '/\bjavascript\b/i' +api:CSharp: '/(\bc\s*sharp\b|\bc#)/i' api:java: '/\bjava\b/i' +api:javascript: '/\bjavascript\b/i' ep:ACL: '/\bacl\b/i' ep:ArmNN: '/\barmnn\b/i' -ep:CUDA: '/\bcuda\b/i' -ep:DML: '/(\bdirectml\b|\bdml\b)/i' -ep:MIGraphX: '/\bmigraphx\b/i' -ep:oneDNN: '/\bonednn\b/i' +ep:CANN: '/\bcann\b/i' +ep:CoreML: '/\bcore\s*ml\b/i' +ep:DML: '/(\bdirect\s*ml\b|\bdml\b)/i' +ep:MIGraphX: '/\bmi\s*graph\s*x\b/i' +ep:oneDNN: '/\bone\s*dnn\b/i' ep:OpenVINO: '/\bopen\s*vino\b/i' -ep:RockchipNPU: '/\brockchip\b/i' +ep:QNN: '/\bqnn\b/i' +ep:RockchipNPU: '/\brockchip(?:npu)?\b/i' ep:ROCm: '/\brocm\b/i' -ep:TensorRT: '/(\btensor\s*rt\b|\btrt\b)/i' +ep:SNPE: '/\bsnpe\b/i' ep:tvm: '/\btvm\b/i' ep:VitisAI: '/\bvitis(?:ai)?\b/i' -platform:jetson: '/\bjetson\b/i' -platform:mobile: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bcore-?ml\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' -platform:web: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b)/i' -platform:windows: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' -model:transformer: '/(\bbert\b|\bgpt-?2\b|\bhugging-?face\b|\blong-?former\b|\bt5\b)/i' -quantization: '/(is this a quantized model\?\n\nYes|\bquantization\b)/i' +ep:WebGPU: '/\bwebgpu\b/i' +ep:WebNN: '/\bwebnn\b/i' +ep:Xnnpack: '/\bxnn\s*pack\b/i' +.NET: '/(\bdot\s*net\b|\bnuget\b|\.net\b)/i' +platform:jetson: '/(\bjetson\b|\bjetpack\b)/i' +platform:mobile: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' +platform:web: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b|\bonnxruntime-react-native\b|\bnpm\b|\btransformers\.js\b)/i' +model:transformer: '/\btransformers(?!\.js)\b/i' diff --git a/.github/policies/issueLabeler.yml b/.github/policies/issueLabeler.yml new file mode 100644 index 000000000000..45ea07484ddc --- /dev/null +++ b/.github/policies/issueLabeler.yml @@ -0,0 +1,318 @@ +name: Issue Triage +description: Assign label to issues +resource: repository +configuration: + resourceManagementConfiguration: + eventResponderTasks: + - if: + - payloadType: Issues + - and: + - isOpen + - not: + and: + - isAssignedToSomeone + - isLabeled + then: + - if: + - titleContains: + pattern: '/\bcuda\b/i' + isRegex: True + then: + - addLabel: + label: ep:CUDA + - if: + - or: + - titleContains: + pattern: '/(\bc\s*sharp\b|\bc#)/i' + isRegex: True + - bodyContains: + pattern: '/(\bc\s*sharp\b|\bc#)/i' + isRegex: True + then: + - addLabel: + label: api:CSharp + - if: + - or: + - titleContains: + pattern: '/\bjava\b/i' + isRegex: True + - bodyContains: + pattern: '/\bjava\b/i' + isRegex: True + then: + - addLabel: + label: api:Java + - if: + - or: + - titleContains: + pattern: '/\bjavascript\b/i' + isRegex: True + - bodyContains: + pattern: '/\bjavascript\b/i' + isRegex: True + then: + - addLabel: + label: api:JavaScript + - if: + - or: + - titleContains: + pattern: '/\bacl\b/i' + isRegex: True + - bodyContains: + pattern: '/\bacl\b/i' + isRegex: 
True + then: + - addLabel: + label: ep:ACL + - if: + - or: + - titleContains: + pattern: '/\barmnn\b/i' + isRegex: True + - bodyContains: + pattern: '/\barmnn\b/i' + isRegex: True + then: + - addLabel: + label: ep:ArmNN + - if: + - or: + - titleContains: + pattern: '/\bcann\b/i' + isRegex: True + - bodyContains: + pattern: '/\bcann\b/i' + isRegex: True + then: + - addLabel: + label: ep:CANN + - if: + - or: + - titleContains: + pattern: '/\bcore\s*ml\b/i' + isRegex: True + - bodyContains: + pattern: '/\bcore\s*ml\b/i' + isRegex: True + then: + - addLabel: + label: ep:CoreML + - if: + - or: + - titleContains: + pattern: '/(\bdirect\s*ml\b|\bdml\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bdirect\s*ml\b|\bdml\b)/i' + isRegex: True + then: + - addLabel: + label: ep:DML + - if: + - or: + - titleContains: + pattern: '/\bmi\s*graph\s*x\b/i' + isRegex: True + - bodyContains: + pattern: '/\bmi\s*graph\s*x\b/i' + isRegex: True + then: + - addLabel: + label: ep:MIGraphX + - if: + - or: + - titleContains: + pattern: '/\bone\s*dnn\b/i' + isRegex: True + - bodyContains: + pattern: '/\bone\s*dnn\b/i' + isRegex: True + then: + - addLabel: + label: ep:oneDNN + - if: + - or: + - titleContains: + pattern: '/\bopen\s*vino\b/i' + isRegex: True + - bodyContains: + pattern: '/\bopen\s*vino\b/i' + isRegex: True + then: + - addLabel: + label: ep:OpenVINO + - if: + - or: + - titleContains: + pattern: '/\bqnn\b/i' + isRegex: True + - bodyContains: + pattern: '/\bqnn\b/i' + isRegex: True + then: + - addLabel: + label: ep:QNN + - if: + - or: + - titleContains: + pattern: '/\brockchip(?:npu)?\b/i' + isRegex: True + - bodyContains: + pattern: '/\brockchip(?:npu)?\b/i' + isRegex: True + then: + - addLabel: + label: ep:RockchipNPU + - if: + - or: + - titleContains: + pattern: '/\brocm\b/i' + isRegex: True + - bodyContains: + pattern: '/\brocm\b/i' + isRegex: True + then: + - addLabel: + label: ep:ROCm + - if: + - or: + - titleContains: + pattern: '/\bsnpe\b/i' + isRegex: True + - bodyContains: + pattern: '/\bsnpe\b/i' + isRegex: True + then: + - addLabel: + label: ep:SNPE + - if: + - titleContains: + pattern: '/(\btensor\s*rt\b|\btrt\b)/i' + isRegex: True + then: + - addLabel: + label: ep:TensorRT + - if: + - or: + - titleContains: + pattern: '/\btvm\b/i' + isRegex: True + - bodyContains: + pattern: '/\btvm\b/i' + isRegex: True + then: + - addLabel: + label: ep:tvm + - if: + - or: + - titleContains: + pattern: '/\bvitis(?:ai)?\b/i' + isRegex: True + - bodyContains: + pattern: '/\bvitis(?:ai)?\b/i' + isRegex: True + then: + - addLabel: + label: ep:VitisAI + - if: + - or: + - titleContains: + pattern: '/\bwebgpu\b/i' + isRegex: True + - bodyContains: + pattern: '/\bwebgpu\b/i' + isRegex: True + then: + - addLabel: + label: ep:WebGPU + - if: + - or: + - titleContains: + pattern: '/\bwebnn\b/i' + isRegex: True + - bodyContains: + pattern: '/\bwebnn\b/i' + isRegex: True + then: + - addLabel: + label: ep:WebNN + - if: + - or: + - titleContains: + pattern: '/\bxnn\s*pack\b/i' + isRegex: True + - bodyContains: + pattern: '/\bxnn\s*pack\b/i' + isRegex: True + then: + - addLabel: + label: ep:Xnnpack + - if: + - or: + - titleContains: + pattern: '/(\bjetson\b|\bjetpack\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bjetson\b|\bjetpack\b)/i' + isRegex: True + then: + - addLabel: + label: platform:jetson + - if: + - or: + - titleContains: + pattern: '/(\bdot\s*net\b|\bnuget\b|\.net\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bdot\s*net\b|\bnuget\b|\.net\b)/i' + isRegex: True + then: + - addLabel: + label: 
.NET + - if: + - or: + - titleContains: + pattern: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' + isRegex: True + then: + - addLabel: + label: platform:mobile + - if: + - or: + - titleContains: + pattern: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b|\bonnxruntime-react-native\b|\bnpm\b|\btransformers\.js\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bwebgl\b|\bweb-?gpu\b|\bwasm\b|\bonnxruntime-node\b|\bonnxruntime-web\b|\bonnxruntime-react-native\b|\bnpm\b|\btransformers\.js\b)/i' + isRegex: True + then: + - addLabel: + label: platform:web + - if: + - titleContains: + pattern: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' + isRegex: True + then: + - addLabel: + label: platform:windows + - if: + - or: + - titleContains: + pattern: '/\btransformers(?!\.js)\b/i' + isRegex: True + - bodyContains: + pattern: '/\btransformers(?!\.js)\b/i' + isRegex: True + then: + - addLabel: + label: model:transformer + - if: + - titleContains: + pattern: '/(quant|\bqdq\b)/i' + isRegex: True + then: + - addLabel: + label: quantization diff --git a/.github/title-only-labeler.yml b/.github/title-only-labeler.yml new file mode 100644 index 000000000000..4980f7251bcb --- /dev/null +++ b/.github/title-only-labeler.yml @@ -0,0 +1,4 @@ +ep:CUDA: '/\bcuda\b/i' +ep:TensorRT: '/(\btensor\s*rt\b|\btrt\b)/i' +platform:windows: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' +quantization: '/(quant|\bqdq\b)/i' diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3965fe063b14..2edbe2d81453 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -73,7 +73,7 @@ jobs: checkout_path: ${{ github.workspace }} lint-cpp: - name: Lint C++ + name: Optional Lint C++ runs-on: ubuntu-latest steps: - uses: actions/checkout@master @@ -89,10 +89,11 @@ jobs: - name: Generate ONNX protobuf files run: cmake --build build/Debug --config Debug --target onnx_proto - uses: reviewdog/action-cpplint@master + continue-on-error: true with: github_token: ${{ secrets.github_token }} reporter: github-pr-check - level: warning + level: info flags: --linelength=120 --exclude=java/src/main/native/*.c --exclude=onnxruntime/core/mlas/inc/* diff --git a/.github/workflows/title-only-labeler.yml b/.github/workflows/title-only-labeler.yml new file mode 100644 index 000000000000..e0af2dd06b1b --- /dev/null +++ b/.github/workflows/title-only-labeler.yml @@ -0,0 +1,20 @@ +name: "Title Only Issue Labeler" +on: + issues: + types: [opened, edited] + +permissions: + issues: write + +jobs: + triage: + runs-on: ubuntu-latest + steps: + - uses: github/issue-labeler@v3.4 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + configuration-path: .github/title-only-labeler.yml + not-before: 2020-01-15T02:54:32Z + enable-versioned-regex: 0 + include-title: 1 + include-body: 0 diff --git a/.lintrunner.toml b/.lintrunner.toml index e6d06b34726f..e1b24b2955b0 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -137,6 +137,7 @@ exclude_patterns = [ 'onnxruntime/core/mickey/gemm/**', # CUTLASS based libs recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks 'onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h', # Bool Switches hang Clang + 'onnxruntime/core/providers/coreml/mlprogram_test_scripts/**', # test scripts only ] command = [ 'python', diff --git 
a/VERSION_NUMBER b/VERSION_NUMBER index 815d5ca06d53..398935591556 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.19.0 +1.20.0 diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index 52bd3f58645f..b2e8f6816a2e 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -73,7 +73,7 @@ def add_github_dep(name, parsed_url): return # Make a REST call to convert to tag to a git commit url = f"https://api.github.com/repos/{org_name}/{repo_name}/git/refs/tags/{tag}" - print("requesting {url} ...") + print(f"requesting {url} ...") res = requests.get(url, auth=(args.username, args.token)) response_json = res.json() tag_object = response_json["object"] diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 66b305a6d36d..f9e702b894f5 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -216,7 +216,7 @@ "component": { "type": "git", "git": { - "commitHash": "06adf4461ac84035bee658c6cf5df39f7ab6071d", + "commitHash": "f161f95883b4ebd8cb789de5efc67b73c0a6e694", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" @@ -351,6 +351,16 @@ }, "comments": "directx_headers" } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "98ca4e1941fe3263f128f74f10063a3ea35c7019", + "repositoryUrl": "https://github.com/NVIDIA/cudnn-frontend.git" + }, + "comments": "cudnn_frontend" + } } ] } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5555fa692eae..2e9a50e52217 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -729,9 +729,6 @@ set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) if (onnxruntime_USE_CUDA) - if (onnxruntime_USE_CUDA_NHWC_OPS) - add_compile_definitions(ENABLE_CUDA_NHWC_OPS) - endif() enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -1445,9 +1442,6 @@ if (onnxruntime_USE_CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) endif() find_package(CUDAToolkit REQUIRED) - if(onnxruntime_CUDNN_HOME) - file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) - endif() if (NOT CMAKE_CUDA_ARCHITECTURES) if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") # Support for Jetson/Tegra ARM devices diff --git a/cmake/deps.txt b/cmake/deps.txt index 9d206b6bb3ae..56d77d9e7002 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -37,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/v0.3.zip;5ec64e3071edc7347ebd8a81679cf06e2bb9b851 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.1.zip;2eb9198bb352757d5ff13977cbe0634898e0837c -#use the latest commit of 10.0-GA -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/06adf4461ac84035bee658c6cf5df39f7ab6071d.zip;46dceef659d75d276e7914a8057c2282269d5e7b +#use the latest commit of 10.2-GA +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/f161f95883b4ebd8cb789de5efc67b73c0a6e694.zip;2148d0c79a171abf2b9451f3bfec164e85caf2ef protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a 
protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -58,3 +58,4 @@ utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e +cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 diff --git a/cmake/external/cuDNN.cmake b/cmake/external/cuDNN.cmake new file mode 100644 index 000000000000..3d05f6406a80 --- /dev/null +++ b/cmake/external/cuDNN.cmake @@ -0,0 +1,111 @@ +add_library(CUDNN::cudnn_all INTERFACE IMPORTED) + +find_path( + CUDNN_INCLUDE_DIR cudnn.h + HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS} + PATH_SUFFIXES include + REQUIRED +) + +file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header) +string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}") +string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}") + +function(find_cudnn_library NAME) + find_library( + ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" + HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib64 lib/x64 lib + REQUIRED + ) + + if(${NAME}_LIBRARY) + add_library(CUDNN::${NAME} UNKNOWN IMPORTED) + set_target_properties( + CUDNN::${NAME} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR} + IMPORTED_LOCATION ${${NAME}_LIBRARY} + ) + message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.") + else() + message(STATUS "${NAME} not found.") + endif() + + +endfunction() + +find_cudnn_library(cudnn) + +include (FindPackageHandleStandardArgs) +find_package_handle_standard_args( + LIBRARY REQUIRED_VARS + CUDNN_INCLUDE_DIR cudnn_LIBRARY +) + +if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY) + + message(STATUS "cuDNN: ${cudnn_LIBRARY}") + message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}") + + set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found") + +else() + + set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found") + +endif() + +target_include_directories( + CUDNN::cudnn_all + INTERFACE + $ + $ +) + +target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn +) + +if(CUDNN_MAJOR_VERSION EQUAL 8) + find_cudnn_library(cudnn_adv_infer) + find_cudnn_library(cudnn_adv_train) + find_cudnn_library(cudnn_cnn_infer) + find_cudnn_library(cudnn_cnn_train) + find_cudnn_library(cudnn_ops_infer) + find_cudnn_library(cudnn_ops_train) + + target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_adv_train + CUDNN::cudnn_ops_train + CUDNN::cudnn_cnn_train + CUDNN::cudnn_adv_infer + CUDNN::cudnn_cnn_infer + CUDNN::cudnn_ops_infer + ) +elseif(CUDNN_MAJOR_VERSION EQUAL 9) + find_cudnn_library(cudnn_cnn) + find_cudnn_library(cudnn_adv) + find_cudnn_library(cudnn_graph) + find_cudnn_library(cudnn_ops) + find_cudnn_library(cudnn_engines_runtime_compiled) + find_cudnn_library(cudnn_engines_precompiled) + find_cudnn_library(cudnn_heuristic) + + target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_adv + 
CUDNN::cudnn_ops + CUDNN::cudnn_cnn + CUDNN::cudnn_graph + CUDNN::cudnn_engines_runtime_compiled + CUDNN::cudnn_engines_precompiled + CUDNN::cudnn_heuristic + ) +endif() + +mark_as_advanced(CUDNN_INCLUDE_DIR) diff --git a/cmake/external/cudnn_frontend.cmake b/cmake/external/cudnn_frontend.cmake new file mode 100644 index 000000000000..7ac51deaf93b --- /dev/null +++ b/cmake/external/cudnn_frontend.cmake @@ -0,0 +1,12 @@ +include(FetchContent) +FetchContent_Declare( + cudnn_frontend + URL ${DEP_URL_cudnn_frontend} + URL_HASH SHA1=${DEP_SHA1_cudnn_frontend} +) + +set(CUDNN_FRONTEND_BUILD_SAMPLES OFF) +set(CUDNN_FRONTEND_BUILD_UNIT_TESTS OFF) +set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF) +set(CUDNN_PATH ${onnxruntime_CUDNN_HOME}) +FetchContent_MakeAvailable(cudnn_frontend) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 14e6ed515fd6..4e5270747405 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -46,9 +46,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(gtest_disable_pthreads ON) endif() - if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(gtest_disable_pthreads ON CACHE BOOL "gtest_disable_pthreads" FORCE) - endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) if (IOS OR ANDROID) # on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing @@ -590,20 +587,16 @@ endif() message("Finished fetching external dependencies") - set(onnxruntime_LINK_DIRS ) + if (onnxruntime_USE_CUDA) - #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same find_package(CUDAToolkit REQUIRED) - if (WIN32) - if(onnxruntime_CUDNN_HOME) - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) - endif() - else() - if(onnxruntime_CUDNN_HOME) - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) - endif() + + if(onnxruntime_CUDNN_HOME) + file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) + set(CUDNN_PATH ${onnxruntime_CUDNN_HOME}) endif() + include(cuDNN) endif() if(onnxruntime_USE_SNPE) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index bdb4b00b02a3..927b4ac84b03 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -38,10 +38,14 @@ function(get_c_cxx_api_headers HEADERS_VAR) # need to add header files for enabled EPs foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) - file(GLOB _provider_headers CONFIGURE_DEPENDS - "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" - ) - list(APPEND _headers ${_provider_headers}) + # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory + # with onnxruntime_c_api.h . Most other EPs probably also do not work in this way. + if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm)) + file(GLOB _provider_headers CONFIGURE_DEPENDS + "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" + ) + list(APPEND _headers ${_provider_headers}) + endif() endforeach() set(${HEADERS_VAR} ${_headers} PARENT_SCOPE) diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index b85edbf37d44..9f8d807fad8f 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -69,7 +69,7 @@ endif() if(onnxruntime_USE_TENSORRT OR onnxruntime_USE_NCCL) # TODO: for now, core framework depends on CUDA. 
It should be moved to TensorRT EP # TODO: provider_bridge_ort.cc should not include nccl.h -target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_CUDNN_HOME}/include PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) else() target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 66f4aea606ef..c02ac2096db2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,8 +555,17 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") +else() + message(STATUS "Using -mavx2 -mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S @@ -575,7 +584,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") set(mlas_platform_srcs_avx512vnni ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index d2afe19f3669..bbcc709b144a 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -219,6 +219,7 @@ if (onnxruntime_ENABLE_TRAINING) endif() install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu/cpu_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) +install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/resource.h ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/custom_op_context.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers) set_target_properties(onnxruntime_providers PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 82c31ce6b6b4..774b7a4f6bd7 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -197,12 +197,16 @@ target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) else() - target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand 
CUDA::cufft CUDA::cudart - ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) - if(onnxruntime_CUDNN_HOME) - target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + include(cudnn_frontend) # also defines CUDNN::* + if (onnxruntime_USE_CUDA_NHWC_OPS) + if(CUDNN_MAJOR_VERSION GREATER 8) + add_compile_definitions(ENABLE_CUDA_NHWC_OPS) + else() + message( WARNING "To compile with NHWC ops enabled please compile against cuDNN 9 or newer." ) + endif() endif() + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas CUDNN::cudnn_all cudnn_frontend CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) endif() if (onnxruntime_USE_TRITON_KERNEL) @@ -289,8 +293,15 @@ config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) - + # Cannot use glob because the file cuda_provider_options.h should not be exposed out. + set(ONNXRUNTIME_CUDA_PROVIDER_PUBLIC_HEADERS + "${REPO_ROOT}/include/onnxruntime/core/providers/cuda/cuda_context.h" + "${REPO_ROOT}/include/onnxruntime/core/providers/cuda/cuda_resource.h" + ) + set_target_properties(onnxruntime_providers_cuda PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_CUDA_PROVIDER_PUBLIC_HEADERS}") install(TARGETS onnxruntime_providers_cuda + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index d738e29101cf..5d1a481d40ab 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -17,8 +17,8 @@ # Header paths find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - if(OpenVINO_VERSION VERSION_LESS 2023.0) - message(FATAL_ERROR "OpenVINO 2023.0 and newer are supported. Please, latest OpenVINO release") + if(OpenVINO_VERSION VERSION_LESS 2024.0) + message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. 
Please, use latest OpenVINO release") endif() if (WIN32) diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 71692ddb9391..559204bd0df8 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -223,8 +223,13 @@ if (onnxruntime_ENABLE_ATEN) target_compile_definitions(onnxruntime_providers_rocm PRIVATE ENABLE_ATEN) endif() - + file(GLOB ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS CONFIGURE_DEPENDS + "${REPO_ROOT}/include/onnxruntime/core/providers/rocm/*.h" + ) + set_target_properties(onnxruntime_providers_rocm PROPERTIES + PUBLIC_HEADER "${ONNXRUNTIME_ROCM_PROVIDER_PUBLIC_HEADERS}") install(TARGETS onnxruntime_providers_rocm + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers/rocm ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 3d46c139feea..468aaa44ec4e 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -159,7 +159,7 @@ if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() - set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + set(trt_link_libs CUDNN::cudnn_all cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) endif() file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h" @@ -183,9 +183,6 @@ endif() target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) - if(onnxruntime_CUDNN_HOME) - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) - endif() # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 270139ceaff7..b2dbe4b3da5e 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -98,11 +98,7 @@ endif() onnxruntime_add_include_to_target(onnxruntime_pybind11_state Python::Module Python::NumPy) target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT} ${pybind11_INCLUDE_DIRS}) if(onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - # cudnn_home is optional for Window when cuda and cudnn are installed in the same directory. 
- if(onnxruntime_CUDNN_HOME) - target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CUDNN_HOME}/include) - endif() + target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR}) endif() if(onnxruntime_USE_CANN) target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include) @@ -512,7 +508,6 @@ file(GLOB onnxruntime_ort_format_model_srcs CONFIGURE_DEPENDS ) file(GLOB onnxruntime_mobile_helpers_srcs CONFIGURE_DEPENDS ${REPO_ROOT}/tools/python/util/mobile_helpers/*.py - ${REPO_ROOT}/tools/ci_build/github/android/mobile_package.required_operators.config ${REPO_ROOT}/tools/ci_build/github/android/nnapi_supported_ops.md ${REPO_ROOT}/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md ${REPO_ROOT}/tools/ci_build/github/apple/coreml_supported_neuralnetwork_ops.md diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index a8c876d30873..1740144cf655 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -145,6 +145,7 @@ set(provider_excluded_files "rnn/rnn_impl.cu" "rnn/rnn_impl.h" "shared_inc/cuda_call.h" + "shared_inc/cudnn_fe_call.h" "shared_inc/fpgeneric.h" "cuda_allocator.cc" "cuda_allocator.h" @@ -171,6 +172,7 @@ set(provider_excluded_files "cuda_utils.cu" "cudnn_common.cc" "cudnn_common.h" + "cudnn_fe_call.cc" "cupti_manager.cc" "cupti_manager.h" "fpgeneric.cu" diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 79bee3bdb65d..b51c87595113 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -44,9 +44,7 @@ if (onnxruntime_USE_EXTENSIONS) endif() add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime") -if (onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -endif() + if (onnxruntime_USE_ROCM) target_compile_options(onnxruntime_session PRIVATE -Wno-sign-compare -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake index 01590a431205..b633a9c2de37 100644 --- a/cmake/onnxruntime_training.cmake +++ b/cmake/onnxruntime_training.cmake @@ -39,10 +39,6 @@ endif() target_include_directories(onnxruntime_training PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header} ${MPI_CXX_INCLUDE_DIRS}) -if (onnxruntime_USE_CUDA) - target_include_directories(onnxruntime_training PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -endif() - if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_training PRIVATE ${NCCL_INCLUDE_DIRS}) endif() @@ -81,9 +77,6 @@ if (onnxruntime_BUILD_UNIT_TESTS) target_include_directories(onnxruntime_training_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header}) target_link_libraries(onnxruntime_training_runner PRIVATE nlohmann_json::nlohmann_json) - if (onnxruntime_USE_CUDA) - 
target_include_directories(onnxruntime_training_runner PUBLIC ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_training_runner PRIVATE ${NCCL_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0c1e5e93c684..d5c3af748e52 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -67,7 +67,7 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. - target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -75,10 +75,13 @@ function(AddTest) onnxruntime_add_include_to_target(${_UT_TARGET} date::date flatbuffers::flatbuffers) target_include_directories(${_UT_TARGET} PRIVATE ${TEST_INC_DIR}) if (onnxruntime_USE_CUDA) - target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) + target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR}) if (onnxruntime_USE_NCCL) target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS}) endif() + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_CUDA_MINIMAL) + endif() endif() if (onnxruntime_USE_TENSORRT) # used for instantiating placeholder TRT builder to mitigate TRT library load/unload overhead @@ -392,7 +395,7 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R ) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src}) - if (onnxruntime_USE_CUDA_NHWC_OPS) + if (onnxruntime_USE_CUDA_NHWC_OPS AND CUDNN_MAJOR_VERSION GREATER 8) file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc" ) @@ -1498,7 +1501,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") list(APPEND custom_op_src_patterns "${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu" "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") - list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) + list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR}) if (HAS_QSPECTRE) list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 7a49e90c00bc..0686b66876d9 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -225,7 +225,7 @@ else() "SHELL:-s EXPORT_ALL=0" "SHELL:-s VERBOSE=0" "SHELL:-s FILESYSTEM=0" - "SHELL:-s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm,mainScriptUrlOrBlob]" + "SHELL:-s INCOMING_MODULE_JS_API=[locateFile,instantiateWasm,wasmBinary]" "SHELL:-s WASM_BIGINT=1" ${WASM_API_EXCEPTION_CATCHING} --no-entry diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 162d33581a5c..6ac3555eeecf 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -86,3 +86,386 @@ index 0aab3e26..398ac2d6 100644 +#endif + #endif // ! 
ONNX_ONNX_PB_H +diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc +index c315a2a7..58963154 100644 +--- a/onnx/defs/math/defs.cc ++++ b/onnx/defs/math/defs.cc +@@ -3472,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA( + } + + auto& input_shape = getInputShape(ctx, 0); ++ if (input_shape.dim_size() < 2) { ++ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), "."); ++ } + auto signal_dim = input_shape.dim(1); + if (!signal_dim.has_dim_value()) { + return; +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index be6a851d..fad595d0 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -126,6 +126,9 @@ void convPoolShapeInference( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +@@ -959,19 +962,21 @@ ONNX_OPERATOR_SET_SCHEMA( + auto w_type = ctx.getInputType(3); + if (nullptr == x_type || nullptr == w_type || x_type->value_case() != TypeProto::kTensorType || + w_type->value_case() != TypeProto::kTensorType) { +- fail_type_inference("inputs are expected to have tensor type."); ++ fail_type_inference("inputs are expected to have tensor type in ", ctx.getDisplayName(), "."); + } + + auto x_zero_point_type = ctx.getInputType(2); + if (nullptr == x_zero_point_type || + x_zero_point_type->tensor_type().elem_type() != x_type->tensor_type().elem_type()) { +- fail_type_inference("input and zero_point pair is expected to have be same type."); ++ fail_type_inference( ++ "input and zero_point pair is expected to have be same type in ", ctx.getDisplayName(), "."); + } + + auto w_zero_point_type = ctx.getInputType(5); + if (nullptr == w_zero_point_type || + w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) { +- fail_type_inference("weight and zero_point pair is expected to have same type."); ++ fail_type_inference( ++ "weight and zero_point pair is expected to have same type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromInputToOutput(ctx, 7, 0); +@@ -2647,7 +2652,8 @@ ONNX_OPERATOR_SET_SCHEMA( + if (!hasNInputShapes(ctx, 1)) { + return; + } +- auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); ++ ++ auto& input_shape = getInputShape(ctx, 0); + int64_t input_ndim = input_shape.dim_size(); + int64_t axis = -1; + auto axis_proto = ctx.getAttribute("axis"); +@@ -2659,7 +2665,16 @@ ONNX_OPERATOR_SET_SCHEMA( + // positive value. + axis += input_ndim; + } +- ++ if (axis < 0) { ++ fail_shape_inference( ++ "Unexpected axis value (", ++ axis, ++ ") rank of first input is ", ++ input_ndim, ++ " in ", ++ ctx.getDisplayName(), ++ "."); ++ } + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index 57f8e2a4..8b2dc07f 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -201,6 +201,9 @@ void convPoolShapeInference_opset19( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? 
effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h +index a80473b3..d1bcd401 100644 +--- a/onnx/defs/shape_inference.h ++++ b/onnx/defs/shape_inference.h +@@ -105,6 +105,10 @@ struct InferenceContext { + virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0; + // Gets the shape inputs computed by partial data propagation. + virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0; ++ // To display a name the user can use to narrow its search. ++ virtual std::string getDisplayName() const { ++ return ""; ++ } + }; + + // We use data propagation to perform partial evaluation of the model, to compute statically +@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput( + } else { + // This is not expected to happen + fail_type_inference( +- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case); ++ "Output ", ++ outputIndex, ++ " expected to have: ", ++ expected_value_case, ++ " or UNDEFINED. Got: ", ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr + const auto attr_type = attr->type(); + if (attr_type == AttributeProto::TENSOR) { + if (attr->t().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim tensor"); ++ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->t().data_type(); + expected_value_case = TypeProto::kTensorType; + } else if (attr_type == AttributeProto::SPARSE_TENSOR) { + if (attr->sparse_tensor().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim sparse tensor"); ++ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->sparse_tensor().values().data_type(); + expected_value_case = TypeProto::kSparseTensorType; + } else { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case); +@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t + const auto* input_type = ctx.getInputType(n); + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); ++ } ++ if (!hasShape(*input_type)) { ++ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return input_type->tensor_type().shape(); +@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size + + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a 
sparse tensor type in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return &input_type->tensor_type().shape(); +@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + " does not match type of output: ", + outputIndex, + "type: ", +- output_value_case); ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + if (TypeProto::kTensorType == input_value_case) { + auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim(); +@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + *dim = input_type->sparse_tensor_type().shape().dim(static_cast(fromDimIndex)); + } else { + fail_type_inference( +- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type"); ++ "Input ", ++ inputIndex, ++ " and Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType + setTensorElementType(elemType, expected_type, *output_type); + } else { + // This is not expected to happen +- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type); ++ fail_type_inference( ++ "Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type: ", ++ expected_type, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput( + updateOutputElemType(ctx, outputIndex, default_value, expected_type); + return; + } else { +- fail_type_inference("Value of attribute ", attributeName, " not specified"); ++ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), "."); + } + } + if (!attr_proto->has_i()) { +- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type."); ++ fail_type_inference( ++ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), "."); + } + auto attr_value = attr_proto->i(); + auto elem_type = static_cast(attr_value); + if (!TensorProto_DataType_IsValid(elem_type)) { +- fail_type_inference("Attribute ", attributeName, " does not specify a valid type."); ++ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), "."); + } + updateOutputElemType(ctx, outputIndex, elem_type, expected_type); + } +@@ -497,7 +529,7 @@ inline TensorShapeProto* + getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) { + auto output_type = ctx.getOutputType(n); + if (output_type == nullptr) { +- fail_type_inference("Output ", n, " expected to have tensor or sparse type"); ++ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), "."); + } + const auto output_value_case = output_type->value_case(); + if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { +@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ + } else if (output_value_case == TypeProto::VALUE_NOT_SET) { + return getTensorMutableShape(default_type, *output_type); + } else { +- fail_type_inference("Output ", n, " expected to have tensor type"); ++ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), "."); 
+ } + } + +@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput( + auto attr_proto = ctx.getAttribute(attributeName); + if ((nullptr == attr_proto) || (!attr_proto->has_type()) || + (attr_proto->type() != AttributeProto_AttributeType_INTS)) { +- fail_shape_inference("Attribute ", attributeName, " should specify a shape"); ++ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), "."); + } + auto& int_list = attr_proto->ints(); + TensorShapeProto shape; + for (auto dim_size : int_list) { + if (dim_size < 0) { +- fail_shape_inference("Negative values are not allowed in a shape specification"); ++ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), "."); + } + shape.add_dim()->set_dim_value(dim_size); + } +@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect + if (hasInputShape(ctx, input_index)) { + auto rank = getInputShape(ctx, input_index).dim_size(); + if (rank != expected_rank) { +- fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank); ++ fail_shape_inference( ++ "Input ", ++ input_index, ++ " expected to have rank ", ++ expected_rank, ++ " but has rank ", ++ rank, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + } +@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind + // This shape is expected to have rank > dim_index: + if (input_shape.dim_size() <= dim_index) { + fail_shape_inference( +- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size()); ++ "Input ", ++ input_index, ++ " expected to have rank >", ++ dim_index, ++ " but has rank ", ++ input_shape.dim_size(), ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + const Dim& input_dim = input_shape.dim(dim_index); + // Now, unify dim and input_dim: +diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc +index 8723dcd4..8249fc59 100644 +--- a/onnx/shape_inference/implementation.cc ++++ b/onnx/shape_inference/implementation.cc +@@ -906,7 +906,7 @@ struct FunctionInferenceContext : public InferenceContext { + const std::vector& input_types, + const std::vector& attributes, + const ShapeInferenceOptions& options) +- : input_types_(input_types), options_(options) { ++ : input_types_(input_types), options_(options), func_proto_(&func_proto) { + for (const auto& attr : attributes) { + attributesByName_[attr.name()] = &attr; + } +@@ -971,11 +971,25 @@ struct FunctionInferenceContext : public InferenceContext { + return std::move(output_types_); + } + ++ std::string getDisplayName() const override { ++ if (func_proto_ == nullptr) ++ return ""; ++ if (func_proto_->domain().empty()) { ++ if (func_proto_->name().empty()) ++ return ""; ++ return MakeString("function ", func_proto_->name()); ++ } ++ if (func_proto_->name().empty()) ++ return MakeString("function [", func_proto_->domain(), "]"); ++ return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]"); ++ } ++ + private: + const std::vector& input_types_; + std::vector output_types_; + std::unordered_map attributesByName_; + ShapeInferenceOptions options_; ++ const FunctionProto* func_proto_; + }; + + std::vector InferFunctionOutputTypes( +diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h +index 2c63c910..b0e4c32d 100644 +--- a/onnx/shape_inference/implementation.h 
++++ b/onnx/shape_inference/implementation.h +@@ -146,7 +146,7 @@ struct InferenceContextImpl : public InferenceContext { + const ShapeInferenceOptions& options, + DataValueMap* generatedShapeData = nullptr, + GraphInferenceContext* graphInferenceContext = nullptr) +- : graphInferenceContext_{graphInferenceContext}, options_(options) { ++ : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) { + for (auto& attr : *n.mutable_attribute()) { + attributesByName_[attr.name()] = &attr; + if (attr.has_g()) { +@@ -277,6 +277,19 @@ struct InferenceContextImpl : public InferenceContext { + return inferencer; + } + ++ std::string getDisplayName() const override { ++ if (node_ == nullptr) ++ return ""; ++ if (node_->domain().empty()) { ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type()); ++ return MakeString("node ", node_->op_type(), " (", node_->name(), ")"); ++ } ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]"); ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")"); ++ } ++ + std::vector allInputData_; + std::vector allInputSparseData_; + std::vector allShapeInputData_; +@@ -289,6 +302,7 @@ struct InferenceContextImpl : public InferenceContext { + // mutable as internal cache of GraphInferencer instances + mutable std::unordered_map> graphAttributeInferencers_; + ShapeInferenceOptions options_; ++ NodeProto* node_; + }; + + struct DataPropagationContextImpl : public DataPropagationContext { diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj index 647c0bbe6a24..29fc9f3bc382 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj @@ -8,7 +8,7 @@ - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 3c8a49bf9357..deb6b4f884bc 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -23,7 +23,7 @@ + '$(IncludeMobileTargets)' == 'true'"> net8.0-android @@ -31,6 +31,43 @@ $(BaseTargets);$(MobileTargets) + + Microsoft.ML.OnnxRuntime + Microsoft.ML.OnnxRuntime + + + + 1.0.0 + 0.0.0 + + + + true + Microsoft.ML.OnnxRuntime C# Bindings + Microsoft + © Microsoft Corporation. All rights reserved. + This package contains ONNX Runtime for .Net platforms + + + $(PackageVersion) + + + + + Microsoft + Microsoft.ML.OnnxRuntime.Managed + ONNX;ONNX Runtime;Machine Learning + https://github.com/Microsoft/onnxruntime + LICENSE.txt + ORT_icon_for_light_bg.png + + Release Def: + Branch: $(BUILD_SOURCEBRANCH) + Commit: $(BUILD_SOURCEVERSION) + Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) + + + AnyCPU;x86 default @@ -43,8 +80,6 @@ $(OnnxRuntimeRoot)\csharp x64 - Microsoft.ML.OnnxRuntime - Microsoft.ML.OnnxRuntime false false portable @@ -54,27 +89,8 @@ on their device is not built for training, an exception will be thrown with the following message - "Training is disabled in the current build. Please build onnxruntime from source with the build flags enable_training_apis. 
"--> - true + true - - - Microsoft.ML.OnnxRuntime.Managed - Microsoft - 1.0.0 - 0.0.0 - $(PackageVersion) - This package contains ONNX Runtime for .Net platforms - ONNX;ONNX Runtime;Machine Learning - https://github.com/Microsoft/onnxruntime - © Microsoft Corporation. All rights reserved. - LICENSE.txt - ORT_icon_for_light_bg.png - - Release Def: - Branch: $(BUILD_SOURCEBRANCH) - Commit: $(BUILD_SOURCEVERSION) - Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) - true @@ -82,7 +98,6 @@ false - false $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb Debug;Release;RelWithDebInfo @@ -158,10 +173,6 @@ $(OrtConstants);__ENABLE_COREML__ - - $(OrtConstants);__XAMARIN__ - - $(DefineConstants);$(OrtConstants) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 1ba5f14641e7..9b1df9357dc8 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -76,7 +76,7 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); #endif - const uint ORT_API_VERSION = 19; + const uint ORT_API_VERSION = 20; #if NETSTANDARD2_0 IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION); api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 75898770acf2..39e75a68a369 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -3,11 +3,11 @@ # SPDX-License-Identifier: MIT #-------------------------------------------------------------------------- -ARG OPENVINO_VERSION=2024.0.0 +ARG OPENVINO_VERSION=2024.2.0 # Build stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} AS builder +FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} AS builder ENV WORKDIR_PATH=/home/openvino WORKDIR $WORKDIR_PATH @@ -34,20 +34,18 @@ RUN cat /etc/apt/sources.list | sed 's/^# deb-src/deb-src/g' > ./temp; mv temp / RUN apt update; apt install dpkg-dev RUN mkdir /sources WORKDIR /sources -RUN apt-get source cron iso-codes lsb-release powermgmt-base python-apt-common python3-apt python3-dbus python3-gi unattended-upgrades libapt-pkg6.0 libhogweed5 libnettle7 +RUN apt-get source cron iso-codes lsb-release powermgmt-base python-apt-common python3-apt python3-dbus python3-gi libapt-pkg6.0 libhogweed6 libnettle8 WORKDIR / RUN tar cvf GPL_sources.tar.gz /sources # Deploy stage -FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} +FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} ENV DEBIAN_FRONTEND noninteractive USER root COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/dist/*.whl ./ COPY --from=builder /GPL_sources.tar.gz ./ RUN python3 -m pip install ./*.whl && rm ./*.whl -RUN apt update; apt install -y unattended-upgrades && \ - unattended-upgrade ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/dockerfiles/scripts/install_common_deps.sh b/dockerfiles/scripts/install_common_deps.sh index 786a6f076a71..41bdc068d8cd 100644 --- a/dockerfiles/scripts/install_common_deps.sh +++ b/dockerfiles/scripts/install_common_deps.sh @@ -21,6 +21,6 @@ pip install "wheel>=0.35.1" rm -rf /opt/miniconda/pkgs # Dependencies: cmake -wget --quiet 
https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.tar.gz -tar zxf cmake-3.27.3-linux-x86_64.tar.gz -rm -rf cmake-3.27.3-linux-x86_64.tar.gz +wget --quiet https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3.30.1-linux-x86_64.tar.gz +tar zxf cmake-3.30.1-linux-x86_64.tar.gz +rm -rf cmake-3.30.1-linux-x86_64.tar.gz diff --git a/docs/ORTMobilePackageOperatorTypeSupport.md b/docs/ORTMobilePackageOperatorTypeSupport.md deleted file mode 100644 index 6a69a2c59882..000000000000 --- a/docs/ORTMobilePackageOperatorTypeSupport.md +++ /dev/null @@ -1,132 +0,0 @@ -# ONNX Runtime Mobile Pre-Built Package Operator and Type Support - -## Supported operators and types - -The supported operators and types are based on what is required to support float32 and quantized versions of popular models. The full list of input models used to determine this list is available [here](https://github.com/microsoft/onnxruntime/blob/main/tools/ci_build/github/android/mobile_package.required_operators.readme.txt) - -## Supported data input types - - - float - - int8_t - - uint8_t - -NOTE: Operators used to manipulate dimensions and indices will support int32 and int64. - -## Supported Operators - -|Operator|Opsets| -|--------|------| -|**ai.onnx**|| -|ai.onnx:Abs|12, 13, 14, 15| -|ai.onnx:Add|12, 13, 14, 15| -|ai.onnx:And|12, 13, 14, 15| -|ai.onnx:ArgMax|12, 13, 14, 15| -|ai.onnx:ArgMin|12, 13, 14, 15| -|ai.onnx:AveragePool|12, 13, 14, 15| -|ai.onnx:Cast|12, 13, 14, 15| -|ai.onnx:Ceil|12, 13, 14, 15| -|ai.onnx:Clip|12, 13, 14, 15| -|ai.onnx:Concat|12, 13, 14, 15| -|ai.onnx:ConstantOfShape|12, 13, 14, 15| -|ai.onnx:Conv|12, 13, 14, 15| -|ai.onnx:ConvTranspose|12, 13, 14, 15| -|ai.onnx:Cos|12, 13, 14, 15| -|ai.onnx:CumSum|12, 13, 14, 15| -|ai.onnx:DepthToSpace|12, 13, 14, 15| -|ai.onnx:DequantizeLinear|12, 13, 14, 15| -|ai.onnx:Div|12, 13, 14, 15| -|ai.onnx:DynamicQuantizeLinear|12, 13, 14, 15| -|ai.onnx:Elu|12, 13, 14, 15| -|ai.onnx:Equal|12, 13, 14, 15| -|ai.onnx:Erf|12, 13, 14, 15| -|ai.onnx:Exp|12, 13, 14, 15| -|ai.onnx:Expand|12, 13, 14, 15| -|ai.onnx:Flatten|12, 13, 14, 15| -|ai.onnx:Floor|12, 13, 14, 15| -|ai.onnx:Gather|12, 13, 14, 15| -|ai.onnx:GatherND|12, 13, 14, 15| -|ai.onnx:Gemm|12, 13, 14, 15| -|ai.onnx:GlobalAveragePool|12, 13, 14, 15| -|ai.onnx:Greater|12, 13, 14, 15| -|ai.onnx:GreaterOrEqual|12, 13, 14, 15| -|ai.onnx:HardSigmoid|12, 13, 14, 15| -|ai.onnx:Identity|12, 13, 14, 15| -|ai.onnx:If|12, 13, 14, 15| -|ai.onnx:InstanceNormalization|12, 13, 14, 15| -|ai.onnx:LRN|12, 13, 14, 15| -|ai.onnx:LayerNormalization|1| -|ai.onnx:LeakyRelu|12, 13, 14, 15| -|ai.onnx:Less|12, 13, 14, 15| -|ai.onnx:LessOrEqual|12, 13, 14, 15| -|ai.onnx:Log|12, 13, 14, 15| -|ai.onnx:LogSoftmax|12, 13, 14, 15| -|ai.onnx:Loop|12, 13, 14, 15| -|ai.onnx:MatMul|12, 13, 14, 15| -|ai.onnx:MatMulInteger|12, 13, 14, 15| -|ai.onnx:Max|12, 13, 14, 15| -|ai.onnx:MaxPool|12, 13, 14, 15| -|ai.onnx:Mean|12, 13, 14, 15| -|ai.onnx:Min|12, 13, 14, 15| -|ai.onnx:Mul|12, 13, 14, 15| -|ai.onnx:Neg|12, 13, 14, 15| -|ai.onnx:NonMaxSuppression|12, 13, 14, 15| -|ai.onnx:NonZero|12, 13, 14, 15| -|ai.onnx:Not|12, 13, 14, 15| -|ai.onnx:Or|12, 13, 14, 15| -|ai.onnx:PRelu|12, 13, 14, 15| -|ai.onnx:Pad|12, 13, 14, 15| -|ai.onnx:Pow|12, 13, 14, 15| -|ai.onnx:QLinearConv|12, 13, 14, 15| -|ai.onnx:QLinearMatMul|12, 13, 14, 15| -|ai.onnx:QuantizeLinear|12, 13, 14, 15| -|ai.onnx:Range|12, 13, 14, 15| -|ai.onnx:Reciprocal|12, 13, 14, 15| -|ai.onnx:ReduceMax|12, 13, 14, 15| -|ai.onnx:ReduceMean|12, 13, 14, 15| 
-|ai.onnx:ReduceMin|12, 13, 14, 15| -|ai.onnx:ReduceProd|12, 13, 14, 15| -|ai.onnx:ReduceSum|12, 13, 14, 15| -|ai.onnx:Relu|12, 13, 14, 15| -|ai.onnx:Reshape|12, 13, 14, 15| -|ai.onnx:Resize|12, 13, 14, 15| -|ai.onnx:ReverseSequence|12, 13, 14, 15| -|ai.onnx:Round|12, 13, 14, 15| -|ai.onnx:Scan|12, 13, 14, 15| -|ai.onnx:ScatterND|12, 13, 14, 15| -|ai.onnx:Shape|12, 13, 14, 15| -|ai.onnx:Sigmoid|12, 13, 14, 15| -|ai.onnx:Sin|12, 13, 14, 15| -|ai.onnx:Size|12, 13, 14, 15| -|ai.onnx:Slice|12, 13, 14, 15| -|ai.onnx:Softmax|12, 13, 14, 15| -|ai.onnx:SpaceToDepth|12, 13, 14, 15| -|ai.onnx:Split|12, 13, 14, 15| -|ai.onnx:Sqrt|12, 13, 14, 15| -|ai.onnx:Squeeze|12, 13, 14, 15| -|ai.onnx:Sub|12, 13, 14, 15| -|ai.onnx:Sum|12, 13, 14, 15| -|ai.onnx:Tanh|12, 13, 14, 15| -|ai.onnx:ThresholdedRelu|12, 13, 14, 15| -|ai.onnx:Tile|12, 13, 14, 15| -|ai.onnx:TopK|12, 13, 14, 15| -|ai.onnx:Transpose|12, 13, 14, 15| -|ai.onnx:Unique|12, 13, 14, 15| -|ai.onnx:Unsqueeze|12, 13, 14, 15| -|ai.onnx:Where|12, 13, 14, 15| -||| -|**com.microsoft**|| -|com.microsoft:DynamicQuantizeMatMul|1| -|com.microsoft:FusedConv|1| -|com.microsoft:FusedGemm|1| -|com.microsoft:FusedMatMul|1| -|com.microsoft:Gelu|1| -|com.microsoft:MatMulIntegerToFloat|1| -|com.microsoft:NhwcMaxPool|1| -|com.microsoft:QLinearAdd|1| -|com.microsoft:QLinearAveragePool|1| -|com.microsoft:QLinearConv|1| -|com.microsoft:QLinearGlobalAveragePool|1| -|com.microsoft:QLinearLeakyRelu|1| -|com.microsoft:QLinearMul|1| -|com.microsoft:QLinearSigmoid|1| -||| diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ed944b5a6df7..529c676321bb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -58,8 +58,8 @@ Do not modify directly.* |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float)| -|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||11|**T** = tensor(float)| |||[6, 10]|**T** = tensor(float)| |Col2Im|*in* input:**T**
*in* image_shape:**tensor(int64)**
*in* block_shape:**tensor(int64)**
*out* output:**T**|18+|**T** = tensor(float)| @@ -272,18 +272,18 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -294,20 +294,20 @@ Do not modify directly.* |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32)| -|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[13, 17]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int64)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| @@ -1178,7 +1178,8 @@ Do not modify directly.* |||13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||18+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(float), tensor(float16)| diff --git a/docs/python/README.rst b/docs/python/README.rst index 6c493e206a49..5a45bf6cef8e 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_. -**OpenVINO™ Execution Provider for ONNX Runtime** Linux Wheels comes with pre-built libraries of OpenVINO™ version 2023.0.0 eliminating the need to install OpenVINO™ separately. The OpenVINO™ libraries are prebuilt with CXX11_ABI flag set to 0. +**OpenVINO™ Execution Provider for ONNX Runtime** Linux Wheels comes with pre-built libraries of OpenVINO™ version 2024.1.0 eliminating the need to install OpenVINO™ separately. For more details on build and installation please refer to `Build `_. Usage ^^^^^ -By default, Intel® CPU is used to run inference. However, you can change the default option to either Intel® integrated or discrete GPU. +By default, Intel® CPU is used to run inference. However, you can change the default option to either Intel® integrated GPU, discrete GPU, integrated NPU (Windows only). Invoke `the provider config device type argument `_ to change the hardware on which inferencing is done. For more API calls and environment variables, see `Usage `_. diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 9289e14c17dd..c51f38553c3b 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1408,6 +1408,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() { return runtime_optimizations_; } + + // We don't run Graph::Resolve() on an ORT format model, but a compiling EP may copy initializers to its + // compiled model during partitioning, leaving them unused in the ORT Graph. To allow the memory to be freed + // we need to manually run the cleanup that would usually happen as part of Graph::Resolve. + Status RemovedUnusedInitializersOrtFormat(); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // This friendship relationship should only be used to call Graph::Graph and @@ -1541,12 +1546,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi common::Status PerformTypeAndShapeInferencing(const ResolveOptions& options); - // Recursively find all subgraphs including nested subgraphs - void FindAllSubgraphs(std::vector& subgraphs); - - // Iterate this Graph instance and all subgraphs, calling the provided function for each. 
- common::Status ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func); - common::Status InferAndVerifyTypeMatch(Node& node, const ONNX_NAMESPACE::OpSchema& op, const ResolveOptions& options); // perform type and shape inferencing on the subgraph and Resolve to validate @@ -1576,9 +1575,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // Implementation for initializer replacement Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); - // Clear all unused initializers and NodeArgs - void CleanUnusedInitializersAndNodeArgs(const std::unordered_set* initializer_names_to_preserve = nullptr); - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, const ArgNameToTypeMap& name_to_type_map); @@ -1587,6 +1583,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + + // Recursively find all subgraphs including nested subgraphs + void FindAllSubgraphs(std::vector& subgraphs); + + // Iterate this Graph instance and all subgraphs, calling the provided function for each. + common::Status ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func); + + // Clear all unused initializers and NodeArgs + void CleanUnusedInitializersAndNodeArgs(const std::unordered_set* initializer_names_to_preserve = nullptr); + Status PopulateNodeArgToProducerConsumerLookupsFromNodes(); template diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 0bb5c7432f0a..6cff153c336f 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -3,12 +3,15 @@ #pragma once +#include #include +#include #include #include #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" +#include "core/framework/tensor.h" #include "core/optimizer/graph_transformer.h" #include "core/platform/threadpool.h" @@ -51,7 +54,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -81,7 +85,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 462b31bb433a..12a19759bb97 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -40,6 +40,7 @@ struct CudaContext : public CustomOpContext { bool enable_skip_layer_norm_strict_mode = false; bool prefer_nhwc = false; bool use_tf32 = true; + bool fuse_conv_bias = true; 
void Init(const OrtKernelContext& kernel_ctx) { cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); @@ -57,6 +58,7 @@ struct CudaContext : public CustomOpContext { kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); + fuse_conv_bias = FetchResource(kernel_ctx, CudaResource::fuse_conv_bias_t); } template diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 01a14de699dc..3b7a1e99346d 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -38,5 +38,6 @@ struct OrtCUDAProviderOptionsV2 { int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not int use_tf32 = 1; // use TF32 + int fuse_conv_bias = 0; // Enable CUDNN Frontend kernel fusing, results in JIT compiles int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index 555023c442c0..b248d33035bc 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -19,4 +19,5 @@ enum CudaResource : int { enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, use_tf32_t, + fuse_conv_bias_t }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5aafdd149e88..234574503c4b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -38,7 +38,7 @@ * * This value is used by some API functions to behave as this version of the header expects. 
*/ -#define ORT_API_VERSION 19 +#define ORT_API_VERSION 20 #ifdef __cplusplus extern "C" { diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5d974e1ff518..29a229f42716 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2216,7 +2216,7 @@ struct ShapeInferContext { size_t GetInputCount() const { return input_shapes_.size(); } - Status SetOutputShape(size_t indice, const Shape& shape); + Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); int64_t GetAttrInt(const char* attr_name); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index aaef111b9f15..9b9dd81a749c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1998,9 +1998,10 @@ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, } } -inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { OrtTensorTypeAndShapeInfo* info = {}; ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); using InfoPtr = std::unique_ptr>; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 17ae649e6f17..209fd4279cc9 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -265,6 +265,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // "1": dump the EP context into the Onnx model. (default). static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; +// Specify the EPContext node name prefix to make it unique +// in case user need to merge/connect multiple EPContext nodes in one model +static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; + // Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. // Option values: // - "0": Gemm FastMath mode is not enabled. [DEFAULT] diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index dbb5f8118363..1a87569a115a 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -74,6 +74,12 @@ export declare namespace Env { */ wasmPaths?: WasmPrefixOrFilePaths; + /** + * Set a custom buffer which contains the WebAssembly binary. If this property is set, the `wasmPaths` property will + * be ignored. + */ + wasmBinary?: ArrayBufferLike|Uint8Array; + /** * Set or get a boolean value indicating whether to proxy the execution of main thread to a worker thread. * diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 43d539b38b6b..450ae2d06e63 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.19.0'; +export const version = '1.20.0'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 68a461aa518a..865fa860e98a 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.19.0", + "version": "1.20.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.19.0", + "version": "1.20.0", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/package.json b/js/common/package.json index ed008eeb4e75..9c941f6486ea 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.19.0", + "version": "1.20.0", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 43d539b38b6b..450ae2d06e63 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.19.0'; +export const version = '1.20.0'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 8962731cdbcf..a0fc445c16dd 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.19.0", + "version": "1.20.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.19.0", + "version": "1.20.0", "hasInstallScript": true, "license": "MIT", "os": [ @@ -29,7 +29,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.19.0", + "version": "1.20.0", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/node/package.json b/js/node/package.json index 2bd24b7c4c25..4964d0fc3fd4 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.19.0", + "version": "1.20.0", "dependencies": { "onnxruntime-common": "file:../common", "tar": "^7.0.1" diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 1c0b141e6a44..1062d89f76c5 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -38,13 +38,13 @@ constexpr size_t DATA_TYPE_ELEMENT_SIZE_MAP[] = { 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 INT64 not working in Javascript + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING N/A 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 not working in Javascript + 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 UINT64 not working in Javascript + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported 0 // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported @@ -60,13 +60,13 @@ constexpr napi_typedarray_type DATA_TYPE_TYPEDARRAY_MAP[] = { napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 napi_int16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 napi_int32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - napi_bigint64_array, // 
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 INT64 not working i + napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING not supported napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 not working + napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 uses Uint16Array napi_float64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE napi_uint32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 UINT64 not working + napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported (napi_typedarray_type)(-1) // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported @@ -182,9 +182,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo * char *buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); - // there is a bug in TypedArray::ElementSize(): https://github.com/nodejs/node-addon-api/pull/705 - // TODO: change to TypedArray::ByteLength() in next node-addon-api release. - size_t bufferByteLength = tensorDataTypedArray.ElementLength() * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]; + size_t bufferByteLength = tensorDataTypedArray.ByteLength(); return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType); } diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 43d539b38b6b..450ae2d06e63 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.19.0'; +export const version = '1.20.0'; diff --git a/js/react_native/package.json b/js/react_native/package.json index d0f9790fcc87..20b5d02ff233 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.19.0", + "version": "1.20.0", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 02239284f893..99c03d2e7bf0 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.19.0" + version "1.20.0" open@^6.2.0: version "6.4.0" diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 8d077846fa6a..9934c758621c 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -13,8 +13,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim |:------:|:------:|:------:|:-:|:-:|:------| | Abs | ai.onnx(7-12, 13+) | abs | ✓ | ✓ | | | Add | ai.onnx(7-12, 13, 14+) | add | ✓ | ✓ | | -| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | -| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | +| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | | +| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | | | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 | | BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✓ | ✓ | Only supports 'training_mode' value is 0, one output | | Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✓ | ✓ | WebNN CPU backend doesn't support casting to uint64 data type | @@ -22,9 +22,10 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | | | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | -| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✗ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). | +| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | +| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | | Erf | ai.onnx(7-9, 10-12, 13+) | erf | ✗ | ✓ | | diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 43d539b38b6b..450ae2d06e63 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.19.0'; +export const version = '1.20.0'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts index f428293add59..a2e542838510 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts @@ -26,6 +26,9 @@ import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; + +import {typeSnippet} from './activation_util'; const arrayProduct = (arr: number[]) => { let product = 1; @@ -218,8 +221,8 @@ export const computeConv3DInfo = export const createConv3DNaiveProgramInfo = (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => { - const isChannelsLast = dataFormat === 'channelsLast'; - const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const isChannelLast = dataFormat === 'channelsLast'; + const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1]; // TODO: enable vec4. const isVec4 = false; const workGroupSize: [number, number, number] = [64, 1, 1]; @@ -228,13 +231,14 @@ export const createConv3DNaiveProgramInfo = LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; + const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides}, {type: DataType.uint32, data: attributes.dilations} ]; + appendActivationUniformsData(attributes, programUniforms); programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length === 3; @@ -251,6 +255,7 @@ export const createConv3DNaiveProgramInfo = {name: 'strides', type: 'u32', length: attributes.strides.length}, {name: 'dilations', type: 'u32', length: attributes.dilations.length} ]; + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 4 : 1; const t = tensorTypeToWsglStorageType(inputs[0].dataType); @@ -266,10 +271,12 @@ export const createConv3DNaiveProgramInfo = inputVariables.push(bias); declareFunctions += ` fn getBiasByOutputCoords(coords : array) -> ${isVec4 ? `vec4<${t}>` : t} { - return bias[${isChannelsLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ + return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ isVec4 ? 
'/ 4' : ''}]; }`; } + const resType = typeSnippet(innerElementSize, t); + const applyActivation = getActivationSnippet(attributes, resType, t); return ` ${declareFunctions} @@ -287,28 +294,28 @@ export const createConv3DNaiveProgramInfo = let coords = ${output.offsetToIndices('global_idx')}; let batch = ${getElementAt('coords', 0, x.rank)}; let d2 = ${ - isChannelsLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; + isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; let xFRCCorner = vec3(${ - isChannelsLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, - ${isChannelsLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, + isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, + ${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, ${ - isChannelsLast ? getElementAt('coords', 3, x.rank) : - getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; + isChannelLast ? getElementAt('coords', 3, x.rank) : + getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; let xFCorner = xFRCCorner.x; let xRCorner = xFRCCorner.y; let xCCorner = xFRCCorner.z; let xShapeY = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; let xShapeZ = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; let xShapeW = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; let xShapeU = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; + isChannelLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; let inputDepthNearestVec4 = (xShapeU / 4) * 4; let inputDepthVec4Remainder = xShapeU % 4; - var dotProd = 0.0; + var value = 0.0; for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) { let xF = xFCorner + wF * uniforms.dilations[0]; if (xF < 0 || xF >= xShapeY) { @@ -329,13 +336,13 @@ export const createConv3DNaiveProgramInfo = for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) { ${ - isChannelsLast ? `let xValues = vec4( + isChannelLast ? `let xValues = vec4( getX(batch, xF, xR, xC, d1), getX(batch, xF, xR, xC, d1 + 1), getX(batch, xF, xR, xC, d1 + 2), getX(batch, xF, xR, xC, d1 + 3)); ` : - `let xValues = vec4( + `let xValues = vec4( getX(batch, d1, xF, xR, xC), getX(batch, d1 + 1, xF, xR, xC), getX(batch, d1 + 2, xF, xR, xC), @@ -346,36 +353,36 @@ export const createConv3DNaiveProgramInfo = getW(d2, d1 + 1, wF, wR, wC), getW(d2, d1 + 2, wF, wR, wC), getW(d2, d1 + 3, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } if (inputDepthVec4Remainder == 1) { ${ - isChannelsLast ? `dotProd += getX(batch, xF, xR, xC, inputDepthNearestVec4) + isChannelLast ? 
`value += getX(batch, xF, xR, xC, inputDepthNearestVec4) * getW(d2, inputDepthNearestVec4, wF, wR, wC);` : - `dotProd += getX(batch, inputDepthNearestVec4, xF, xR, xC) + `value += getX(batch, inputDepthNearestVec4, xF, xR, xC) * getW(d2, inputDepthNearestVec4, wF, wR, wC);`} } else if (inputDepthVec4Remainder == 2) { ${ - isChannelsLast ? `let xValues = vec2( + isChannelLast ? `let xValues = vec2( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1)); ` : - `let xValues = vec2( + `let xValues = vec2( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC)); `} let wValues = vec2( getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } else if (inputDepthVec4Remainder == 3) { ${ - isChannelsLast ? `let xValues = vec3( + isChannelLast ? `let xValues = vec3( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2)); ` : - `let xValues = vec3( + `let xValues = vec3( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC)); @@ -384,19 +391,20 @@ export const createConv3DNaiveProgramInfo = getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC)); - dotProd += dot(xValues, wValues); + value += dot(xValues, wValues); } } } } - ${hasBias ? 'dotProd = dotProd + getBiasByOutputCoords(coords)' : ''}; - result[global_idx] = f32(dotProd); + ${hasBias ? 'value = value + getBiasByOutputCoords(coords)' : ''}; + ${applyActivation} + result[global_idx] = f32(value); }`; }; return { name: 'Conv3DNaive', shaderCache: - {hint: `${attributes.cacheKey};${isChannelsLast};${innerElementSize};${hasBias}`, inputDependencies}, + {hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 6e66abacf347..cfa0b42ef9ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -30,6 +30,10 @@ export const getActivationSnippet = baseType}(uniforms.beta)));`; case 'LeakyRelu': return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case 'Tanh': + return `let e2x = exp(-2.0 * abs(value)); + value = sign(value) * (1.0 - e2x) / (1.0 + e2x); + `; case '': return ''; // TODO: adding other activations that can be fused. diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index fb068ab42d04..0f5f10716a00 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -108,6 +108,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag; const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm; const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? 
wasmPathOverrideFlag; + const wasmBinaryOverride = flags.wasmBinary; const [objectUrl, ortWasmFactory] = (await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1)); @@ -135,7 +136,12 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise numThreads, }; - if (wasmPathOverride || wasmPrefixOverride) { + if (wasmBinaryOverride) { + /** + * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching. + */ + config.wasmBinary = wasmBinaryOverride; + } else if (wasmPathOverride || wasmPrefixOverride) { /** * A callback function to locate the WebAssembly file. The function should return the full path of the file. * diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 3cfc0457c623..1d3b7f161c28 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.19.0", + "version": "1.20.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.19.0", + "version": "1.20.0", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -50,7 +50,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.19.0", + "version": "1.20.0", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/web/package.json b/js/web/package.json index b4f59902097a..11e18a5ae170 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.19.0", + "version": "1.20.0", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 6a10e3b96a26..d88c91ebc9de 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -430,5 +430,38 @@ ] } ] + }, + { + "name": "fused conv with tanh", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Tanh", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [0.11, 0.12, 0.13, 0.14], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.15572261810302734, 0.20409323275089264, 0.29770541191101074, 0.3425688147544861], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/fused-conv3dncdhw.jsonc b/js/web/test/data/ops/fused-conv3dncdhw.jsonc new file mode 100644 index 000000000000..1801ca380aa0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv3dncdhw.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "fused conv3d with relu, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [-0.125, -0.25, -0.375, 0.5, 0.625, -0.75, 
-0.875, -1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0625, 0], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv3d with clip", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "activation_params", "data": [1.0, 3.0], "type": "floats" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2.1875], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv3d with HardSigmoid, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu", + "operator": "FusedConv", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "activation_params", "data": [0.1, 0.3], "type": "floats" }, + { "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" }, + { "name": "auto_pad", "data": "VALID", "type": "string" }, + { "name": "strides", "data": [1, 1, 1], "type": "ints" }, + { "name": "dilations", "data": [1, 1, 1], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.25, 0.5, 0.75, 1], + "dims": [1, 1, 2, 1, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1], + "dims": [2, 1, 2, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.39375001192092896, 0.518750011920929], + "dims": [1, 2, 1, 1, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/e2e/browser-test-wasm-binary-override.js b/js/web/test/e2e/browser-test-wasm-binary-override.js new file mode 100644 index 000000000000..35d427fa3b72 --- /dev/null +++ b/js/web/test/e2e/browser-test-wasm-binary-override.js @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +'use strict'; + +const documentUrl = document.currentScript.src; + +it('Browser E2E testing - WebAssembly backend', async function() { + // preload .wasm file binary + const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.wasm', documentUrl).href; + const response = await fetch(wasmUrl); + + // make sure the .wasm file is loaded successfully + assert(response.ok); + assert(response.headers.get('Content-Type') === 'application/wasm'); + + // override wasm binary + const binary = await response.arrayBuffer(); + ort.env.wasm.wasmBinary = binary; + + await testFunction(ort, {executionProviders: ['wasm']}); +}); diff --git a/js/web/test/e2e/run-data.js b/js/web/test/e2e/run-data.js index 507192f29be9..856f29eac6dd 100644 --- a/js/web/test/e2e/run-data.js +++ b/js/web/test/e2e/run-data.js @@ -36,6 +36,9 @@ const BROWSER_TEST_CASES = [ [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2', 'proxy=1']], // 2 threads, proxy [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=1', 'proxy=1']], // 1 thread, proxy + // wasm binary override: + [true, false, './browser-test-wasm-binary-override.js', 'ort.min.js'], + // path override: // wasm, path override filenames for both mjs and wasm, same origin [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 4a3a23bfe91b..4aaf9d16b2b0 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1358,6 +1358,7 @@ "fast-gelu.jsonc", "floor.jsonc", "fused-conv.jsonc", + "fused-conv3dncdhw.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 944740a4ccad..e4d85c9d7b97 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.19.0" +__version__ = "1.20.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). 
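Note on the `env.wasm.wasmBinary` addition above: it lets a caller hand onnxruntime-web a pre-fetched WebAssembly binary so the runtime skips its own .wasm file lookup (when it is set, `wasmPaths` is ignored), which is what the new browser-test-wasm-binary-override.js e2e case exercises. A minimal TypeScript sketch of that usage follows; the .wasm URL and the model path are placeholders for illustration only and are not part of this change.

import * as ort from 'onnxruntime-web';

async function main(): Promise<void> {
  // Fetch the WebAssembly binary up front (URL is a placeholder for wherever
  // the app serves the onnxruntime-web dist files from).
  const response = await fetch('./dist/ort-wasm-simd-threaded.wasm');

  // Hand the bytes to ORT; with wasmBinary set, wasmPaths is ignored and the
  // runtime performs no .wasm fetch of its own.
  ort.env.wasm.wasmBinary = await response.arrayBuffer();

  // Placeholder model; any ONNX model runnable on the wasm EP works here.
  const session = await ort.InferenceSession.create('./model.onnx', {
    executionProviders: ['wasm'],
  });

  const input = new ort.Tensor('float32', new Float32Array([1, 2, 3, 4]), [1, 4]);
  const results = await session.run({ [session.inputNames[0]]: input });
  console.log(results);
}

main();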
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 515a967aa238..f7d8fedc734e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -258,7 +258,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->scale = scale_; output_parameters->mask_type = mask_type; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = false; output_parameters->qkv_format = Q_K_V_BNSH; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 55292b35e1e3..88127387d08e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -6,6 +6,12 @@ namespace onnxruntime { namespace contrib { +enum AttentionType { + kAttention, + kMultiHeadAttention, + kDecoderMaskedMultiHeadAttention, +}; + enum AttentionMaskType { MASK_NONE, // No mask MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length @@ -24,10 +30,12 @@ enum AttentionQkvFormat { UNKNOWN, // enum value not set, or depends on qkv projection implementation details Q_K_V_BNSH, // for non-packed qkv, permuted Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention - QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BSNH_BNSH_BNSH, // for cross attention, k and v are permuted Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) - Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. 
+ Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + QKV_BSN3H, // for TRT fused attention, qkv are packed + QKV_BS3NH, // for DecoderMaskedMultiHeadAttention, qkv are packed QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed }; @@ -61,7 +69,6 @@ struct AttentionParameters { bool past_present_share_buffer; bool do_rotary; bool broadcast_res_pos_bias; - bool pass_past_in_kv; float mask_filter_value; float scale; bool use_tf32; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 9677c30f22d8..0d7737677923 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -85,7 +85,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { scale_, is_unidirectional_, past_present_share_buffer, - false)); + kMultiHeadAttention)); const int batch_size = parameters.batch_size; const int q_sequence_length = parameters.sequence_length; @@ -121,20 +121,13 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - // For each of Q/K/V, there are multiple scenarios: - // 1) Combined QKV bias is null - // a) Q/K/V is (B, S, D) - // b) Q/K/V is (B, S, N, H) - // 2) No packed QKV in Q - // a) Q/K/V has seq_len = 1 - // b) Q/K/V has seq_len > 1 - OrtValue Q; ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q)); - if (parameters.pass_past_in_kv) { // key and value in BNSH format - assert(bias == nullptr); + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + // For cross attention with k and v in BNSH format, we assume that bias for key and value are zeros. + // So we don't need to add bias for key and value here. assert(past_key == nullptr); assert(past_value == nullptr); return ApplyAttention(Q.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index bd7ab0965917..cfb8d3684377 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -11,6 +11,232 @@ namespace onnxruntime { namespace contrib { namespace multihead_attention_helper { +template +Status Check_QKV(const T* packed_qkv, AttentionQkvFormat& qkv_format) { + const auto& query_dims = packed_qkv->Shape().GetDims(); + if (query_dims.size() == 3) { + // Packed qkv used by DecoderMaskedMultiHeadAttention. Query shape is (B, S, 3D), no key and value. 
+ qkv_format = AttentionQkvFormat::QKV_BS3NH; + } else { + assert(query_dims.size() == 5); + if (static_cast(query_dims[3]) != 3) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'query' shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv"); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + + return Status::OK(); +} + +template +Status Check_Q_KV(const T* query, const T* packed_kv, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = packed_kv->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query to be 3 for packed kv"); + } + + if (key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key to be 5 for packed kv"); + } + + if (key_dims[0] != query_dims[0] || + static_cast(key_dims[2]) != num_heads || + static_cast(key_dims[3]) != 2 || + static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + kv_sequence_length = static_cast(key_dims[1]); + return Status::OK(); +} + +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length, int& v_hidden_size) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = key->Shape().GetDims(); + const auto& value_dims = value->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query to be 3"); + } + + if (key_dims.size() != value_dims.size() || (key_dims.size() != 3 && value_dims.size() != 4)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key and value to be the same, and either 3 or 4"); + } + + if (key_dims[0] != query_dims[0] || value_dims[0] != query_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query', 'key' and 'value' shall have same dim 0 (batch_size)"); + } + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + + if (key_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same dim 1 (kv_sequence_length)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + kv_sequence_length = static_cast(key_dims[1]); + v_hidden_size = static_cast(value_dims[2]); + } else { // key_dims.size() == 4 + if (value->Shape() != key->Shape() || + static_cast(key_dims[1]) != num_heads || + static_cast(key_dims[3]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same shape (batch_size, num_heads, kv_sequence_length, head_size)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + kv_sequence_length = static_cast(key_dims[2]); + v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); + } + + return Status::OK(); +} + +template +Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, + int batch_size, int num_heads, int head_size, bool past_present_share_buffer, + int& past_sequence_length, int& 
max_sequence_length) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 1 should be same as number of heads, got ", + past_key_dims[1]); + } + if (past_value_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 1 should be same as number of heads, got ", + past_value_dims[1]); + } + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); + } + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + past_sequence_length = static_cast(past_key_dims[2]); + if (past_present_share_buffer) { + max_sequence_length = static_cast(past_key_dims[2]); + if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); + } + past_sequence_length = *((*past_seq_len).template Data()); + } + return Status::OK(); +} + +template +Status CheckRelativePositionBias( + const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length, + bool& broadcast_res_pos_bias) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[0] == 1) { + broadcast_res_pos_bias = true; + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + return Status::OK(); +} + +template +AttentionMaskType GetMaskType(const T* key_padding_mask, int batch_size, int sequence_length, int total_sequence_length) { + AttentionMaskType mask_type = AttentionMaskType::MASK_UNKNOWN; + const auto& mask_dims = key_padding_mask->Shape().GetDims(); + if (mask_dims.size() == 1) { + if (mask_dims[0] == static_cast(batch_size)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; + } + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; + } + return mask_type; +} + template Status CheckInputs(const T* query, const T* key, @@ -27,176 +253,128 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing) { - // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None - // relative_position_bias : (B, 1, S, L) - // past_key : (B, N, S*, H) - // past_value : (B, N, S*, H) - // When no packing for q/k/v: + AttentionType operator_type) { + // --------------------------------------------------------------- + // Notations: + // B: batch_size + // N: num_heads + // H: head_size (V might have different head size than Q and K) + // D: hidden_size = N * H + // S: q_sequence_length + // P: past_sequence_length + // L: kv_sequence_length + // T: total_sequence_length = P + L + // M: max_sequence_length + // --------------------------------------------------------------- + // MultiHeadAttention inputs: + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: // query (Q) : (B, S, D) - // key (K) : (B, L, D) or (B, N, S*, H) - // value (V) : (B, L, D_v) or (B, N, S*, H) - // bias (Q/K/V) : (D + D + D_v) - // When packed kv is used: + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) - // key (K) : (B, L, N, 2, H) - // value (V) : None - // bias (Q/K/V) : None - // When packed qkv is used: - // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): + // query (Q) : (B, S, D) + // key (K/V) : (B, L, N, 2, H) + // value : None + // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v): + // query (Q/K/V) : (B, S, N, 3, H) + // key : None + // value : None + // + // Other inputs: + // bias (Q/K/V) : None or (D + D + D_v) + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) + // relative_position_bias : (B, N, S, T) or (1, N, S, T) + // past_key : 
(B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // --------------------------------------------------------------- + // DecoderMaskedMultiHeadAttention inputs (S == 1, D == D_v): + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T): + // query (Q) : (B, S, D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // QKV_BS3NH - packed qkv (S == L): + // query (Q) : (B, S, 3 * D) // key (K) : None // value (V) : None - // bias (Q/K/V) : None or (D + D + D_v) - - AttentionQkvFormat qkv_format; + // + // Other inputs: + // bias (Q/K/V) : None or (3 * D) + // key_padding_mask (K/V) : None or (B, T) + // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. + // + // The following inputs are not used in cross attention (so they are None for cross attention): + // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_value : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_sequence_length : scalar (1) when past_present_share_buffer is True. + // CUDA version has extra inputs (beam_width, cache_indirection) that are not checked in the class. + // For ROCm, see contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh for more details. + // --------------------------------------------------------------- + AttentionQkvFormat qkv_format = UNKNOWN; const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3 && query_dims.size() != 5) { + + int query_rank = static_cast(query_dims.size()); + if (query_rank != 3 && query_rank != 5) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ", - query_dims.size()); + query_rank); } int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); - int hidden_size = (query_dims.size() == 3) + bool dmmha_packing = operator_type == kDecoderMaskedMultiHeadAttention && key == nullptr && value == nullptr; + int hidden_size = (query_rank == 3) ? (dmmha_packing ? 
(static_cast(query_dims[2]) / 3) : static_cast(query_dims[2])) : (num_heads * static_cast(query_dims[4])); int head_size = static_cast(hidden_size) / num_heads; int kv_sequence_length = sequence_length; + int v_hidden_size = hidden_size; + if (key != nullptr) { + if (value == nullptr) { + ORT_RETURN_IF_ERROR(Check_Q_KV(query, key, num_heads, head_size, qkv_format, kv_sequence_length)); + } else { + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, head_size, + qkv_format, kv_sequence_length, v_hidden_size)); + } + } else if (value == nullptr) { // no key and value + ORT_RETURN_IF_ERROR(Check_QKV(query, qkv_format)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value' shall be absent when 'key' is absent"); + } + int past_sequence_length = 0; int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - if (past_key_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 1 should be same as number of heads, got ", - past_key_dims[1]); - } - if (past_value_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 1 should be same as number of heads, got ", - past_value_dims[1]); - } - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). 
", - past_key_dims[2], " vs ", past_value_dims[2]); - } - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - past_sequence_length = static_cast(past_key_dims[2]); - max_sequence_length = static_cast(past_key_dims[2]); - if (past_present_share_buffer) { - if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); - } - past_sequence_length = *((*past_seq_len).template Data()); - } + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, past_seq_len, + batch_size, num_heads, head_size, past_present_share_buffer, + past_sequence_length, max_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ", - query_dims.size()); - } - - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3 && key_dims.size() != 4 && key_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3, 4, or 5 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { + if (operator_type == kMultiHeadAttention) { + if (qkv_format == AttentionQkvFormat::QKV_BS3NH) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); + "Packed qkv of 3D BS3NH format is not support by MultiHeadAttention"); } - if (key_dims.size() == 3) { - if (key_dims[2] != query_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else if (key_dims.size() == 5) { - if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); - } - - qkv_format = Q_KV_BSNH_BSN2H; - kv_sequence_length = static_cast(key_dims[1]); - } else { // key_dims.size() == 4 (cross-attention with past_key) - if (static_cast(key_dims[1]) != num_heads || static_cast(key_dims[3]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, num_heads, kv_sequence_length, head_size)"); - } - - if (value == nullptr || value->Shape().GetDims().size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' shall be 4D when 'key' is 4D"); - } - - if (bias != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when 'key' is 4D"); - } - - qkv_format = UNKNOWN; - kv_sequence_length = 
static_cast(key_dims[2]); - } - } else { // packed QKV - if (query_dims.size() != 3 && query_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions when key is empty, got ", - query_dims.size()); - } - if (query_dims.size() == 5 && (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3)) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv"); + if (qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H && bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when packed kv is used"); } - - qkv_format = QKV_BSN3H; } if (bias != nullptr) { @@ -206,116 +384,31 @@ Status CheckInputs(const T* query, bias_dims.size()); } - if (value == nullptr) { - // Currently, bias is not allowed for packed KV. This constraint can be removed later. - // Here we assume that fusion tool will not include bias for packed KV. - if (query_dims.size() == 5 && query_dims[3] == 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. "); - } + int expected_bias_length = 2 * hidden_size + v_hidden_size; + if (bias_dims[0] != static_cast(expected_bias_length)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' length is expected to be 2 * hidden_size + hidden_size_v, got ", + bias_dims.size()); } } int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { - mask_type = AttentionMaskType::MASK_UNKNOWN; - const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1) { - if (mask_dims[0] == static_cast(batch_size)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(kv_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(sequence_length) && - mask_dims[2] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_3D_ATTENTION; - } - + mask_type = GetMaskType(key_padding_mask, batch_size, sequence_length, total_sequence_length); if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); - } - } - - // NOTE: In Cross-Attention, we pass the past key and value to 'key' and 'value' instead of 'past_key' and 'past_value'. 
- bool pass_past_in_kv = false; - int v_hidden_size = hidden_size; - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3 && value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 or 4 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (value_dims.size() == 3) { - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } - v_hidden_size = static_cast(value_dims[2]); - } else { // value_dims.size() == 4 - if (static_cast(kv_sequence_length) != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 2 (kv_sequence_length)"); - } - - if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be empty when 'value' is 4D"); - } - - v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); - pass_past_in_kv = true; + "Input 'key_padding_mask' shape is not expected."); } } bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - if (relative_position_bias_dims[3] != total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); - } + ORT_RETURN_IF_ERROR(CheckRelativePositionBias( + relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias)); } - // TODO: ORT_RETURN_IF(qkv_format == UNKNOWN, "Unrecognized QKV format"); + assert(qkv_format != UNKNOWN); + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; @@ -323,7 +416,7 @@ Status CheckInputs(const T* query, output_parameters->past_sequence_length = past_sequence_length; output_parameters->kv_sequence_length = kv_sequence_length; output_parameters->total_sequence_length = total_sequence_length; - output_parameters->max_sequence_length = 
max_sequence_length; + output_parameters->max_sequence_length = past_present_share_buffer ? max_sequence_length : total_sequence_length; output_parameters->input_hidden_size = 0; output_parameters->hidden_size = hidden_size; output_parameters->v_hidden_size = v_hidden_size; @@ -336,7 +429,6 @@ Status CheckInputs(const T* query, output_parameters->mask_type = mask_type; output_parameters->scale = scale; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = pass_past_in_kv; output_parameters->qkv_format = qkv_format; } @@ -359,7 +451,7 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing, + AttentionType operator_type, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); @@ -367,7 +459,7 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, - past_present_share_buffer, dmmha_packing); + past_present_share_buffer, operator_type); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 995babc85735..5fdd2b017b8a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -104,6 +104,8 @@ class MatMulNBits final : public OpKernel { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + const Tensor* tensor_zero_point = nullptr; + has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point); #ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; @@ -139,6 +141,7 @@ class MatMulNBits final : public OpKernel { IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) bool is_asym_{false}; @@ -207,10 +210,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(ORT_NEURAL_SPEED) - +#else // defined(ORT_NEURAL_SPEED) + ORT_UNUSED_PARAMETER(prepacked_weights); + const auto compute_type = static_cast(accuracy_level_); if (input_idx == InputIndex::B) { - const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { return Status::OK(); } @@ -220,12 +223,20 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; + } else if (compute_type == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, 
block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } +#endif } #endif // defined(ORT_NEURAL_SPEED) @@ -332,9 +343,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { + const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + M, N, K, batch_count, nbits_, block_size_, compute_type); + if (workspace_size > 0) { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); @@ -344,14 +355,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); data[i].QuantBScale = scales_data; data[i].QuantBZeroPoint = zero_points_data; data[i].Bias = bias_data; data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), thread_pool); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 34a1da99316a..030cdb1e1b17 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -143,6 +143,8 @@ Status GptSubgraph::Validate(const std::vector& subgraph_inputs, // Past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads). const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape == nullptr, + "subgraph past state cannot be nullptr"); ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size()); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 9e6752b45186..62d6a723bf32 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -520,6 +520,39 @@ __global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) } } +template +__global__ void AddBiasTransposeUnpack(int M, const T* input, const T* biases, T* output) { + // Format 5 to unpack TRT packed input format to BNSH for unfused attention. 
+ // Input: BxSxNxMxH + // Output: MxBxNxSxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = (s + n * sequence_length) * head_size + (b + m * batch_size) * NHS; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -692,6 +725,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else if (format == 5) { // format == 5 + AddBiasTransposeUnpack<<>>(total_matrix_count, input, biases, output); } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } @@ -716,6 +751,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); + } else if (format == 5) { // format == 5 + ORT_THROW("AddBiasTranspose (format 5) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -904,6 +941,7 @@ void InvokeAddBias( AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } } + // K { const dim3 grid(kv_sequence_length, batch_size, num_matrices); @@ -1011,6 +1049,82 @@ void LaunchAddBias( } } +template +void InvokeAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q) { + assert(num_heads <= max_threads_per_block); + constexpr int num_matrices = 1; + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const float* biases, const float* query, float* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const float4* query2 = reinterpret_cast(query); + const float4* biases2 = reinterpret_cast(biases); + float4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const float2* query2 = reinterpret_cast(query); + const float2* biases2 = reinterpret_cast(biases); + float2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, 
H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const half* biases, const half* query, half* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const Half4* query2 = reinterpret_cast(query); + const Half4* biases2 = reinterpret_cast(biases); + Half4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const half2* query2 = reinterpret_cast(query); + const half2* biases2 = reinterpret_cast(biases); + half2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index efc31db43bcd..bd4e123a272b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -3,14 +3,15 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { namespace cuda { -// Fused kernel of Add (bias) and Transpose. +// Fused kernel of Add bias (optional, can be None) and Transpose. // Shape of inputs and outputs: -// biases: (num_matrices, num_heads * head_size) +// biases: (num_matrices, num_heads * head_size) or None // format 0: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (num_matrices, batch_size, sequence_length, num_heads, head_size) // output: (num_matrices, batch_size, num_heads, sequence_length, head_size) @@ -24,9 +25,12 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) -// format 4: (requires qk_head_size = v_head_size) +// format 4: (requires qk_head_size == v_head_size) // input: (batch_size, sequence_length, num_heads, num_matrices, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 5: (requires qk_head_size == v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, num_heads, sequence_length, head_size) template void LaunchAddBiasTranspose( @@ -35,7 +39,7 @@ void LaunchAddBiasTranspose( const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr, int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0); -// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format. +// Add bias (optional, can be None) and Transpose for separated inputs of Q, K and V, and output Trt format. 
// For self attention: // output: (batch_size, sequence_length, num_heads, 3, head_size) // It assumes sequence_length == kv_sequence_length and head_size == v_head_size. @@ -50,7 +54,7 @@ void LaunchAddBiasTransposeTrt( const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length = -1); -// Add (bias) for separated inputs of Q, K and V. +// Add bias (required) for separated inputs of Q, K and V. // Q: (batch_size, sequence_length, num_heads, head_size) // K: (batch_size, kv_sequence_length, num_heads, head_size) // V: (batch_size, kv_sequence_length, num_heads, v_head_size) @@ -61,6 +65,46 @@ void LaunchAddBias( const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v); +// Add bias (required) for Q: (batch_size, sequence_length, num_heads, head_size) +template +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q); + +// Add bias (optional, can be None) transpose kernel defined in packed_multihead_attention_impl.cu. +// Support the following format transforms (for float and half only). +// source_format => target_format: +// Q_K_V_TNH => Q_K_V_BNSH (requires token_offset) +// Q_K_V_TNH => Q_K_V_TNH +// Q_K_V_TNH => QKV_TN3H +// QKV_TN3H => Q_K_V_BNSH (requires token_offset) +// QKV_TN3H => Q_K_V_TNH +// QKV_TN3H => QKV_TN3H +template +void AddBiasTransposePacked( + const T* query, const T* key, const T* value, const T* bias, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +// Add bias (required) transpose kernel defined in packed_attention_impl.cu. 
+// Support the following format transforms (for float and half only): +// format transform +// Q_K_V_BNSH: Tx3xNxH => 3xBxNxSxH (requires token_offset) +// Q_K_V_BSNH: Tx3xNxH => 3xTxNxH +// QKV_BSN3H: Tx3xNxH => TxNx3xH +template +void AddBiasTransposePacked( + const T* input, const T* biases, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3b7f980ba188..5c0989bced70 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -260,7 +260,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; @@ -281,6 +282,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 997493acd9cb..f9eabe27d97e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,31 +58,25 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { - if (this->sequence_length != seq_length) { - ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); - LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, seq_length, stream); - this->sequence_length = seq_length; +const int32_t* CumulatedSequenceLengthCache::TryGet(int batch_size, int32_t seq_len, cudaStream_t stream) { + if (this->sequence_length == 0 && seq_len > 0) { + // Initialize only once with sequence length in the first request. + std::call_once(init_once_flag_, [&]() { + ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, + this->max_batch_size, seq_len, stream); + // Synchronize to ensure thread safety since other threads will not wait for the above kernel to finish. + // Otherwise, the data might be consumed by other threads before it is ready, causing a data race. 
+ cudaStreamSynchronize(stream); + this->sequence_length = seq_len; + }); } -} -int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache, - const int* mask_index, - int batch_size, - int sequence_length, - cudaStream_t stream, - void* scratch_buffer) { - if (mask_index == nullptr && cache != nullptr) { - if (batch_size <= cache->max_batch_size) { - cache->Initialize(sequence_length, stream); - return reinterpret_cast(cache->buffer.get()); - } + if (this->sequence_length == seq_len && batch_size <= this->max_batch_size) { + return reinterpret_cast(buffer.get()); } - int* sequence_offset = reinterpret_cast(scratch_buffer); - LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream); - return sequence_offset; + return nullptr; } size_t GetAttentionScratchSize( @@ -114,10 +108,12 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention) { + bool use_memory_efficient_attention, + bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. - const size_t qkv_bytes = element_size * batch_size * num_heads * - ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_size = element_size * batch_size * num_heads * + ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_bytes = no_qkv_workspace ? 0 : qkv_size; #if USE_FLASH_ATTENTION if (use_flash_attention) { @@ -162,39 +158,44 @@ Status FusedTrtCrossAttention( // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. 
assert(data.mask_index == nullptr); - + assert(data.scratch != nullptr); + assert(data.q != nullptr); + assert(data.k != nullptr); + +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + 2 * GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, - sequence_length, stream, - data.scratch); + int32_t* q_sequence_offset = const_cast(data.cumulated_sequence_length_q_cache); + if (q_sequence_offset == nullptr) { + q_sequence_offset = reinterpret_cast(data.scratch); + LaunchTrtSequenceOffset(q_sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_INIT(); DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, parameters.kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + int32_t* kv_sequence_offset = const_cast(data.cumulated_sequence_length_kv_cache); + if (kv_sequence_offset == nullptr) { + int* scratch = reinterpret_cast(data.scratch) + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = reinterpret_cast(scratch); + LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, parameters.kv_sequence_length, stream); + } + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); - // When there is no bias, we can directly use q and packed kv from inputs. 
- void const* query = data.q; - void const* packed_kv = data.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV + data.q, // Q + data.k, // packed KV q_sequence_offset, // cumulated sequence length of Q kv_sequence_offset, // cumulated sequence length of KV data.output, // output @@ -206,8 +207,6 @@ Status FusedTrtCrossAttention( parameters.kv_sequence_length, // sequence length of KV stream); - DUMP_TENSOR("trt cross output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -225,24 +224,33 @@ Status FusedTrtSelfAttention( cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { + assert(data.scratch != nullptr); +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const bool causal = parameters.is_unidirectional; - int* sequence_offset = reinterpret_cast(data.scratch); - - DUMP_TENSOR_INIT(); + const int32_t* sequence_offset = data.cumulated_sequence_length_q_cache; if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + LaunchTrtSequenceOffset2d(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); + if (sequence_offset == nullptr) { + LaunchTrtSequenceOffset(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); + } } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length); @@ -252,22 +260,12 @@ Status FusedTrtSelfAttention( if (!causal) { assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. 
- void const* packed_qkv = data.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + fused_fp16_runner->Run(b, s, data.q, sequence_offset, data.output, stream); } else { assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); } @@ -289,38 +287,19 @@ Status FlashAttention( contrib::AttentionParameters& parameters, AttentionData& data, float scale) { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); assert(nullptr == data.mask_index); assert(nullptr == data.relative_position_bias); assert(parameters.head_size == parameters.v_head_size); - void* query = reinterpret_cast(data.q); - void* key = reinterpret_cast(data.k); - void* value = reinterpret_cast(data.v); - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); - - bool is_bf16 = false; + constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - true)); - - DUMP_TENSOR("flash attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); } @@ -351,25 +330,8 @@ Status EfficientAttention( float scale) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = data.q; - const void* key = data.k; - const void* value = data.v; - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -394,21 +356,19 @@ Status EfficientAttention( ? nullptr : const_cast(reinterpret_cast( data.mask_index + 2 * parameters.batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; + p.query = data.q; + p.key = data.k; + p.value = data.v; p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; - p.is_kv_bsnh = true; + p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; p.stream = stream; p.has_custom_right_padding = false; run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -449,10 +409,6 @@ Status UnfusedAttention( cublasSetStream(cublas, stream); - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); - const int present_sequence_length = parameters.past_present_share_buffer ? 
parameters.max_sequence_length : total_sequence_length; @@ -467,8 +423,7 @@ Status UnfusedAttention( &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop, parameters.use_tf32)); - DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_INIT(); DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); constexpr size_t element_size = sizeof(T); @@ -523,7 +478,6 @@ Status UnfusedAttention( // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, device_prop.maxThreadsPerBlock, false, temp_output, data.output); - DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } @@ -554,7 +508,7 @@ Status QkvToContext( if (!parameters.past_present_share_buffer) { ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, parameters.pass_past_in_kv, + sequence_length, total_sequence_length, stream, max_threads_per_block, data)); } else { // past_present_share_buffer diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 56836bdda197..fad353dcfeb0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include "core/framework/allocator.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -15,13 +17,18 @@ namespace cuda { constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128; +// A cache for cumulated sequence length. It will be initialized in the first request, then becomes read-only after that. struct CumulatedSequenceLengthCache { onnxruntime::IAllocatorUniquePtr buffer; int32_t max_batch_size; int32_t sequence_length; - CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {} - void Initialize(int32_t sequence_length, cudaStream_t stream); + CumulatedSequenceLengthCache() : max_batch_size(kCumulatedSequenceLengthCacheMaxBatchSize), sequence_length(0) {} + + const int32_t* TryGet(int batch_size, int32_t sequence_length, cudaStream_t stream); + + // Use this flag to guard the initialization so that it happens only once in multi-threading. 
+ mutable std::once_flag init_once_flag_; }; size_t @@ -46,7 +53,8 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention); + bool use_memory_efficient_attention, + bool no_qkv_workspace); template struct AttentionData { @@ -65,8 +73,6 @@ struct AttentionData { bool has_qkv_workspace = false; T* workspace = nullptr; - T* temp_k_workspace = nullptr; - T* temp_v_workspace = nullptr; T* output = nullptr; T* present = nullptr; @@ -79,22 +85,50 @@ struct AttentionData { bool use_flash_attention = false; bool use_memory_efficient_attention = false; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; + const int32_t* cumulated_sequence_length_q_cache = nullptr; + const int32_t* cumulated_sequence_length_kv_cache = nullptr; // Intermediate data T* q = nullptr; T* k = nullptr; T* v = nullptr; T* scratch = nullptr; - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + AttentionQkvFormat qkv_format = AttentionQkvFormat::UNKNOWN; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + + // For Debugging + size_t workspace_bytes = 0; + bool allow_debug_info = false; + + bool IsUnfused() const { + return !use_flash_attention && !use_memory_efficient_attention && + (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr); + } + + void PrintDebugInfo() const { + std::cout << "flash=" << use_flash_attention + << ", efficient=" << use_memory_efficient_attention + << ", fused_runner=" << (fused_runner != nullptr) + << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) + << ", bias=" << (bias != nullptr) + << ", attn_bias=" << (relative_position_bias != nullptr) + << ", mask_dims=" << mask_index_dims.size() + << ", has_qkv_workspace=" << has_qkv_workspace + << ", workspace=" << workspace_bytes + << ", past=" << (past != nullptr ? 1 : (past_key != nullptr ? 2 : 0)) + << ", present=" << (present != nullptr ? 1 : (present_key != nullptr ? 2 : 0)) + << std::endl; + } }; +// Return true if it does not need qkv workspace, false otherwise. 
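+// The QKV workspace holds the transposed or bias-added copies of Q, K and V. When the selected kernel can
+// consume the inputs in place (see the NoQkvWorkspace_MHA_* helpers in attention_prepare_qkv.cu), no such
+// copies are made and GetAttentionWorkspaceSize can leave that part of the allocation out.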
+template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, @@ -129,6 +163,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, int total_matrix_count = -1); +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block); + Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const half* input, half* output, cudaStream_t stream, const int max_threads_per_block); @@ -158,7 +195,7 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index bd7df5f490c7..aba1e01bfd91 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -50,6 +50,7 @@ class AttentionKernelOptions { bool use_unfused_{true}; bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; // Causal attention is disabled by default in #14732. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 89be0f1115f4..9f0f49348c22 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -249,16 +249,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, - cudaStream_t stream, - int max_threads_per_block, + int sequence_length, int total_sequence_length, + cudaStream_t stream, int max_threads_per_block, AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. - if (nullptr != data.present) { + if (nullptr != data.present) { // Attention op assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); @@ -270,58 +269,52 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int // Update pointers to present_k and present_v. 
data.k = data.present; data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } else if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + } else { // MultiHeadAttention op + if (nullptr != data.present_key) { + ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + if (nullptr != data.past_key) { + assert(data.past_key != data.k); + assert(data.past_value != data.v); + + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + // Update pointers to present_k and present_v. data.k = data.present_key; data.v = data.present_value; - } else { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - data.k = data.temp_k_workspace; - data.v = data.temp_v_workspace; + } else { // nullptr == data.past_key && nullptr != data.present_key + if (data.k != data.present_key) { + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } + + if (data.v != data.present_value) { + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_value, data.v, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } } - } else if (pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, data.k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, data.v, data.present_value)); - // Update pointers to present_k and present_v. 
- data.k = data.present_key; - data.v = data.present_value; } } + return CUDA_CALL(cudaGetLastError()); } // Template Instantiation template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 040d6124e745..05c592ec6105 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -12,12 +12,101 @@ namespace onnxruntime { namespace contrib { namespace cuda { +#if DEBUG_TENSOR_LEVEL > 1 +// Dump the workspace for Q, K, V after processing QKV data. +template +void DumpQkv(AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("q(BSN3H)", data.q, batch_size, sequence_length, num_heads * 3, qk_head_size); + } +} + +// Dump the inputs before processing QKV data. 
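+// Like DumpQkv above, this helper is only compiled when DEBUG_TENSOR_LEVEL > 1.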
+template +void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BSNH)", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Key(BSNH)", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Value(BSNH)", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("Query(BSN3H)", data.query, batch_size, sequence_length, num_heads * 3, qk_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BSN2H)", data.value, batch_size, sequence_length, num_heads * 2, qk_head_size); + } + + if (data.bias != nullptr) { + DUMP_TENSOR_D("Q_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("K_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } + + if (data.relative_position_bias != nullptr) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + parameters.broadcast_res_pos_bias ? 
1 : batch_size, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr) { + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length); + } + if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1); + } + } +} + +// Dump the kernel outputs +template +void DumpOutputs(AttentionData& data) { + DUMP_TENSOR_INIT(); + DUMP_TENSOR("output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); +} +#endif + template Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -40,7 +129,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, int matrix_to_trans = (past_present_share_buffer ? 1 : 3); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } else { // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) @@ -48,13 +137,13 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, // For fused causal kernel, use format 1 since we need have K and V to update present state, // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); + data.qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); // For fused causal, we will update gemm_buffer with bias directly. T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; @@ -71,367 +160,526 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, return Status::OK(); } -// For MultiHeadAttention with past state +// Return true if the workspace is not needed for Q, K, V inputs, false otherwise. +// This shall be in sync with the following function PrepareQkv_MHA_Cross. template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { +bool NoQkvWorkspace_MHA_Cross(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. 
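+ // That is, flash or memory efficient attention can read the BSNH query and the BNSH key/value buffers in place,
+ // so nothing needs to be copied into the workspace.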
+ return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr); +} + +// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format) +template +Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + // past_key or past_value is not supported for cross attention + // present_key and present_value can be supported in theory, although we do not allow the senario for now. + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_Cross(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - DUMP_TENSOR_INIT(); - if (data.bias == nullptr) { - // Below logic does not support fused attention with past without bias - // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. - - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Add bias for Q + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, + data.bias, data.query, data.q); + } else { + data.q = const_cast(data.query); } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance + // Here we have assumption that there is no bias for key and value when they are in BNSH format. 
+ data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + if (data.bias == nullptr) { + // Transpose query from BSNH to BNSH ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); + max_threads_per_block, false, data.query, data.q)); + } else { + // Add bias to query, and transpose it: Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + + // Here we have assumption that there is no bias for key and value when they are in BNSH format. + // So we do not need to add bias for key and value. Just use the key and value directly. + data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_NoPast(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr; +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + assert(data.mask_index == nullptr); + assert(parameters.hidden_size == parameters.v_hidden_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); + data.v = nullptr; + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; } #if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if 
((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, data.q, data.k, data.v); + } else { + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } + + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. 
get initial kv to use as past_kv in the next iteration)
- else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
- data.present_key != nullptr &&
- data.present_value != nullptr) {
- // Use memory efficient attention kernel
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
-
- // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.temp_k_workspace, data.present_key));
+#endif
+ else if (data.fused_runner != nullptr) {
+ assert(qk_head_size == v_head_size);
+ assert(data.relative_position_bias == nullptr);
+
+ // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
+ data.k = nullptr;
+ data.v = nullptr;
+
+ data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
+ assert(data.IsUnfused());
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, data.q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
+ true, -1);
- // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
+ true, -1);
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+
+ return Status::OK();
+}
+
+template
+bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData& data) {
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ // Q, K and V redirect to query, present_k and present_v, so we do not need extra workspace for QKV.
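+ // (PrepareQkv_MHA_WithPast_NoBias below transposes K and V straight into the present buffers in that case.)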
+ return data.past_key == nullptr && data.present_key != nullptr; + } + return false; +} + +// For MultiHeadAttention with kv cache (past or present), but no bias +template +Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.bias == nullptr); + assert(data.fused_runner == nullptr); + assert(data.fused_cross_attention_kernel == nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.past_key == nullptr && data.past_value == nullptr || + data.past_key != nullptr && data.past_value != nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // When there is no past state and there is present state, we output K and V directly to present state. + if (data.past_key == nullptr && data.present_key != nullptr) { + data.k = data.present_key; + data.v = data.present_value; + } + +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Use oiginal Query (BSNH) since there is no bias. + data.q = const_cast(data.query); + + // Key (BxLxNxH) => K (BxNxLxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + // Value (BxLxNxH) => V (BxNxLxH) ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, data.q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + + return Status::OK(); +} - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +template +constexpr bool NoQkvWorkspace_MHA_WithPast_Bias(AttentionData& /*data*/) { + return false; +} + +// For MultiHeadAttention with both kv cache (past or present) and bias +template +Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format 
== AttentionQkvFormat::Q_K_V_BSNH);
+ assert(data.bias != nullptr);
+ assert(!(data.past_key != nullptr && data.present_key == nullptr));
+ assert(data.fused_runner == nullptr);
+ assert(data.fused_cross_attention_kernel == nullptr);
+ assert(data.present_key != nullptr);
+ assert(data.present_value != nullptr);
+ assert(data.past_key == nullptr && data.past_value == nullptr ||
+ data.past_key != nullptr && data.past_value != nullptr);
+ assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_Bias(data));
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ // When there is no past state and there is present state, we output K and V directly to present state.
+ if (data.past_key == nullptr && data.present_key != nullptr) {
+ data.k = data.present_key;
+ data.v = data.present_value;
}
+
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH)
+ LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
+ data.bias, data.query, data.q);
+
+ // Key (BxLxNxH) + Bias_K => K (BxNxLxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, data.bias + num_heads * qk_head_size, data.k, true, -1);
+
+ // Value (BxLxNxH_v) + Bias_V => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(
+ stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1);
+
+ data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+ } else
#endif
- else {
- // Use unfused kernel for Q, use unfused kernel for K and V if needed
+ { // unfused kernel
+ assert(data.IsUnfused());
+ constexpr int format = 0;
// Query (BxSxNxH) => Q (BxNxSxH)
LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, q,
+ data.query, data.bias, data.q,
true, -1);
- if (!parameters.pass_past_in_kv) {
- T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
- T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ?
data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, data.k, + true, -1); - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +template +bool NoQkvWorkspace_MHA_PackedQKV(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return nullptr != data.fused_runner && data.bias == nullptr; +} + // For MultiHeadAttention without past state, with packed QKV inputs template Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedQKV(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + // unpack qkv to BSNH. 
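+ // (format 4 asks LaunchAddBiasTranspose to split the packed BxSxNx3xH input into separate Q, K and V
+ // buffers in BSNH layout.)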
constexpr int format = 4; T* qkv_add_bias = nullptr; LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, + data.query, data.bias, data.q, true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (nullptr != data.fused_runner) { + assert(nullptr == data.relative_position_bias); + if (data.bias == nullptr) { + // When there is no bias, we can directly use the original packed QKV input. + // Need revisit this when we add support for causal. + data.q = const_cast(data.query); + data.k = nullptr; + data.v = nullptr; + } else { // data.bias != nullptr + AddBiasTransposePacked( + data.query, data.key, data.value, data.bias, data.q, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + AttentionQkvFormat::QKV_TN3H, AttentionQkvFormat::QKV_TN3H, + nullptr, batch_size * sequence_length, + stream); } - qkv_format = AttentionQkvFormat::QKV_BSN3H; + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // unpack qkv to BNSH + constexpr int format = 5; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, v_head_size, qkv_add_bias, 3); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +// This shall be in sync with the following function PrepareQkv_MHA_PackedQKV. +template +bool NoQkvWorkspace_MHA_PackedKV(AttentionData& data) { + return data.fused_cross_attention_kernel != nullptr; +} + // For MultiHeadAttention without past state, with packed KV inputs template Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); + assert(data.bias == nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_runner == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedKV(data)); + const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. 
- assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + // Note that there is no bias so we need not output query to q. + data.q = const_cast(data.query); + // Unpack kv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, batch_size, kv_sequence_length, num_heads, qk_head_size, data.key, kv_bias, data.k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } + true, v_head_size, qkv_add_bias); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (data.fused_cross_attention_kernel != nullptr) { + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = nullptr; + } else { // unfused kernel + assert(data.IsUnfused()); + // Transpose q from BSNH to BNSH. Note that there is no bias. + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(batch_size, parameters.sequence_length, num_heads, qk_head_size, + data.query, data.q, stream, max_threads_per_block)); - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + // Unpack kv to BNSH. + constexpr int format = 5; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, data.k, + true, v_head_size, qkv_add_bias, 2); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } -// For MultiHeadAttention without past state, with Q, K and V inputs +// Prepare Q, K and V for MultiHeadAttention operator. 
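+// It dispatches on parameters.qkv_format to the format-specific helpers above.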
template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif - - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_Cross(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block)); + break; + case 
AttentionQkvFormat::QKV_BSN3H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block)); + } + } else { // no past state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block)); + } + break; + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } -#endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + return Status::OK(); +} - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; +// Check whether there is no needed to have workspace for Q, K and V for MultiHeadAttention operator. +// Please make it in sync with PrepareQkv_MultiHeadAttention. 
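+// A rough sketch of the intended call pattern in the kernel (local variable names here are illustrative only):
+//   const bool no_qkv_workspace = NoQkvWorkspace(parameters, data);
+//   size_t workspace_bytes = GetAttentionWorkspaceSize(..., use_memory_efficient_attention, no_qkv_workspace);
+//   data.has_qkv_workspace = !no_qkv_workspace;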
+template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + return NoQkvWorkspace_MHA_Cross(data); + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + return NoQkvWorkspace_MHA_PackedKV(data); + case AttentionQkvFormat::QKV_BSN3H: + return NoQkvWorkspace_MHA_PackedQKV(data); + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + return NoQkvWorkspace_MHA_WithPast_NoBias(data); + } else { + return NoQkvWorkspace_MHA_WithPast_Bias(data); + } + } else { // no past state + return NoQkvWorkspace_MHA_NoPast(data); + } + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } - return Status::OK(); } template @@ -439,7 +687,6 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block) { - data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -452,28 +699,37 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.k = data.workspace + elements_q; data.v = data.k + elements_k; data.scratch = data.v + elements_v; + } else { + data.q = nullptr; + data.k = nullptr; + data.v = nullptr; + data.scratch = data.workspace; } +#if DEBUG_TENSOR_LEVEL > 1 + DumpInputs(parameters, data); +#endif + if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - data.qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block)); + } else { // MultiHeadAttention operator + ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention(parameters, data, stream, max_threads_per_block)); } + assert(data.qkv_format != AttentionQkvFormat::UNKNOWN); + +#if DEBUG_TENSOR_LEVEL > 1 + DumpQkv(data); +#endif + CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } // Template Instantiation +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index bd38a21aadfc..9f3e396b7f94 100644 --- 
a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -304,6 +304,12 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c max_threads_per_block, false, input, output); } +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 66c0aceaed1e..037a4fdf3d9a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -75,7 +75,6 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); bool is_unidirectional = false; - bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, @@ -91,7 +90,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* scale_, is_unidirectional, past_present_share_buffer_, - is_dmmha_packing, // dmmha_packing + kDecoderMaskedMultiHeadAttention, device_prop.maxThreadsPerBlock)); if (bias) { @@ -157,7 +156,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.is_cross_attention = true; parameters.total_sequence_length = parameters.kv_sequence_length; parameters.max_sequence_length = parameters.kv_sequence_length; - // parameters.k and paraneters.v are nullptr + // parameters.k and parameters.v are nullptr parameters.k_cache = const_cast(key->Data()); parameters.v_cache = const_cast(value->Data()); parameters.k_bias = nullptr; @@ -188,12 +187,14 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* } parameters.is_cross_attention = false; - parameters.is_packed_qkv = is_dmmha_packing; - parameters.k = is_dmmha_packing + bool is_packed_qkv = (key == nullptr && value == nullptr); + parameters.is_packed_qkv = is_packed_qkv; + + parameters.k = is_packed_qkv ? const_cast(query->Data() + parameters.hidden_size) : const_cast(key->Data()); - parameters.v = is_dmmha_packing + parameters.v = is_packed_qkv ? const_cast(query->Data() + 2 * static_cast(parameters.hidden_size)) : const_cast(value->Data()); parameters.k_cache = present_key_data; diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 9efb6f08e8e9..2f8d277cb734 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -183,6 +183,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; } + // This has assumption that key and value does not have bias for cross attention when they are in BNSH format. 
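+ // (the K load and bias add below therefore run only on the self attention path)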
if (!params.is_cross_attention) {
Qk_vec_k k;
@@ -580,6 +581,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
+
+ // This has assumption that key and value does not have bias for cross attention when they are in BNSH format.
if (params.v_bias && !params.is_cross_attention) {
zero(v_bias);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index 90f0b94cafce..967c04c52b18 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -92,6 +92,11 @@ void set_params_fprop(Flash_fwd_params& params,
params.softmax_lse_ptr = softmax_lse_d;
// Set the dimensions.
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4267) // Ignore conversion from 'size_t' to 'int', possible loss of data
+#pragma warning(disable : 4244) // Ignore conversion from 'double' to 'float', possible loss of data
+#endif
params.b = batch_size;
params.h = num_heads;
params.h_k = num_heads_k;
@@ -119,6 +124,9 @@
if (window_size_left >= 0 && window_size_right < 0) {
window_size_right = seqlen_k;
}
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 663bd020ddac..c36abc8e1d62 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -7,6 +7,7 @@
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@@ -44,7 +45,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
scale_ = info.GetAttrOrDefault("scale", 0.0f);
is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
- ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");
+ ORT_ENFORCE(!is_unidirectional_,
+ "Unidirectional MHA is not supported by the CUDA kernel. Consider using Attention or GQA instead.");
kernel_options_ = this->GetAttentionKernelOptions();
@@ -95,7 +97,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
scale_,
is_unidirectional_,
false, // past_present_share_buffer
- false, // dmmha_packing
+ kMultiHeadAttention,
device_prop.maxThreadsPerBlock));
int sequence_length = parameters.sequence_length;
@@ -111,25 +113,43 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
- MHARunner* fused_runner = nullptr;
+ int num_past = static_cast(past_key != nullptr) + static_cast(past_value != nullptr);
+ int num_present = static_cast(present_key != nullptr) + static_cast(present_value != nullptr);
+ if (num_past == 0 && num_present == 0) {
+ // It is a valid case without past state.
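+ // (neither KV cache inputs nor present outputs were provided, so there is nothing further to check here.)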
+ } else if ((num_past == 2 && num_present == 2) || (num_past == 0 && num_present == 2)) { + if (parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed QKV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed KV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for cross attention"); + } + } else { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be all provided, " + "or all empty, or only present_key and present_value are provided"); + } + MHARunner* fused_runner = nullptr; const FusedMultiHeadCrossAttentionKernel* fused_cross_attention_kernel = nullptr; // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; - bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - - const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); - -#if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION - // Exclude this case since PrepareQkv will convert the format to BNSH. - bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; -#endif - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - !past_no_bias && nullptr == relative_position_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -138,7 +158,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. 
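The validation added above accepts exactly three key/value cache configurations, and rejects any cache tensors for packed QKV, packed KV, and cross-attention formats. A small helper that captures the same accept rule is sketched below; the function name is hypothetical and the snippet is for illustration only.

// Illustrative only: the (past, present) input counts accepted by the new checks.
// 0/0: no cache tensors; 2/2: past and present both provided; 0/2: present only.
static bool IsSupportedPastPresentCombo(int num_past, int num_present) {
  return (num_past == 0 && num_present == 0) ||
         (num_past == 2 && num_present == 2) ||
         (num_past == 0 && num_present == 2);
}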
- if (use_flash_attention && key == nullptr && value == nullptr && + if (use_flash_attention && parameters.qkv_format == AttentionQkvFormat::QKV_BS3NH && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } @@ -162,19 +182,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - bool use_fused_cross_attention = !use_flash_attention && - !disable_fused_cross_attention_ && - nullptr == key_padding_mask && - nullptr == relative_position_bias && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - key != nullptr && - (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV - parameters.hidden_size == parameters.v_hidden_size && - has_fused_cross_attention_kernel(sm, parameters.head_size, - parameters.kv_sequence_length); + bool use_fused_cross_attention = + !use_flash_attention && + !disable_fused_cross_attention_ && + nullptr == key_padding_mask && + nullptr == relative_position_bias && + nullptr == past_key && nullptr == present_key && + (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && + parameters.hidden_size == parameters.v_hidden_size && + has_fused_cross_attention_kernel(sm, parameters.head_size, + parameters.kv_sequence_length); if (use_fused_cross_attention) { if (fused_fp16_cross_attention_kernel_ == nullptr) { - fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + std::call_once(fused_cross_init_once_flag_, [&]() { + fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -184,17 +206,18 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !use_flash_attention && - !disable_fused_self_attention_ && - fused_cross_attention_kernel == nullptr && - nullptr == relative_position_bias && - (value != nullptr || key == nullptr) && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - (nullptr == key_padding_mask || is_mask_1d_seq_len) && - parameters.hidden_size == parameters.v_hidden_size && - parameters.sequence_length == parameters.kv_sequence_length && - FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + bool use_fused_runner = + !use_flash_attention && + !disable_fused_self_attention_ && + fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && + (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && + nullptr == past_key && nullptr == present_key && + (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN) && + parameters.hidden_size == parameters.v_hidden_size && + parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads and head_size do not change for a MultiHeadAttention node.
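Note the switch above from a bare null check to std::call_once for fused_fp16_cross_attention_kernel_: concurrent Compute calls on the same kernel instance now initialize the pointer exactly once, and it is effectively read-only afterwards. The same pattern applies to the fused FP16 runner created just below. A self-contained sketch of the pattern with generic names (not the actual ORT classes):

#include <mutex>

// Illustrative sketch of thread-safe lazy initialization with std::call_once.
class LazyKernel {
 public:
  const void* Get(int sm) const {
    std::call_once(init_flag_, [&]() { kernel_ = Load(sm); });
    return kernel_;  // read-only after the first successful call
  }

 private:
  static const void* Load(int /*sm*/) {
    static const int dummy_kernel = 0;  // stand-in for get_fused_cross_attention_kernels(sm)
    return &dummy_kernel;
  }
  mutable const void* kernel_ = nullptr;
  mutable std::once_flag init_flag_;
};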
if (nullptr == fused_fp16_runner_.get()) { @@ -214,10 +237,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); - bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 + bool is_long_sequence = std::is_same::value || // sequence length threshold is 0 for FP16 parameters.sequence_length >= length_threshold || parameters.kv_sequence_length >= length_threshold; + // Check whether the relative position bias alignment is good for memory efficient attention. bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; bool use_memory_efficient_attention = @@ -226,82 +250,25 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - !past_no_bias && (relative_position_bias == nullptr || is_good_for_rpb) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); + has_memory_efficient_attention(sm, std::is_same::value, + parameters.head_size, parameters.v_head_size); #else constexpr bool use_memory_efficient_attention = false; #endif - if (kernel_options_->AllowDebugInfo()) { - AttentionKernelDebugInfo debug_info; - debug_info.use_flash_attention = use_flash_attention; - debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; - debug_info.use_efficient_attention = use_memory_efficient_attention; - if (fused_fp16_runner_ != nullptr) { - debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); - } - - debug_info.Print("MultiHeadAttention", - this->Node().Name(), - std::is_same::value, - std::is_same::value); - } - - // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. - // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. - bool no_qkv_workspace = nullptr == value && - (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && - nullptr == key_padding_mask && - nullptr == bias; - - size_t workspace_bytes; - constexpr size_t element_size = sizeof(T); - if (no_qkv_workspace) { - workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 
2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0; - } else { - workspace_bytes = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.v_head_size, - parameters.sequence_length, - parameters.kv_sequence_length, - parameters.total_sequence_length, - fused_runner, - use_flash_attention, - use_fused_cross_attention, - use_memory_efficient_attention); - } - - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - const size_t past_k_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.head_size; - const size_t past_v_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.v_head_size; - const bool use_temp_k_v_workspace = parameters.pass_past_in_kv || use_memory_efficient_attention || use_flash_attention; - auto temp_k_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; - auto temp_v_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; - typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); - data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); - data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); + data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) - : (nullptr == past_key) ? nullptr - : reinterpret_cast(past_key->Data()); - data.past_value = pass_key_value_as_past ? reinterpret_cast(value->Data()) - : (nullptr == past_value) ? nullptr - : reinterpret_cast(past_value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; - data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); @@ -309,8 +276,41 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); - data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + + // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). + // The cache will be initialized only once, and become readonly after that. + if ((data.fused_cross_attention_kernel != nullptr || data.fused_runner != nullptr) && data.mask_index == nullptr) { + cudaStream_t stream = Stream(context); + data.cumulated_sequence_length_q_cache = this->cumulated_sequence_length_q_cache_.TryGet( + parameters.batch_size, parameters.sequence_length, stream); + + if (data.fused_cross_attention_kernel != nullptr) { + data.cumulated_sequence_length_kv_cache = this->cumulated_sequence_length_kv_cache_.TryGet( + parameters.batch_size, parameters.kv_sequence_length, stream); + } + } + + const bool no_qkv_workspace = NoQkvWorkspace(parameters, data); + size_t workspace_bytes = GetAttentionWorkspaceSize(sizeof(T), + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.v_head_size, + parameters.sequence_length, + parameters.kv_sequence_length, + parameters.total_sequence_length, + fused_runner, + use_flash_attention, + use_fused_cross_attention, + use_memory_efficient_attention, + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + data.allow_debug_info = kernel_options_->AllowDebugInfo(); if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } @@ -318,8 +318,23 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - cublasHandle_t cublas = GetCublasHandle(context); + if (data.allow_debug_info) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + + data.PrintDebugInfo(); + } + cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 26e38dbad9fd..68fd0c9943fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include 
"contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" @@ -32,11 +33,16 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + + // These mutable members are readonly after they are initialized so that they can be shared among multiple threads. + // Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource. mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; + mutable std::once_flag fused_cross_init_once_flag_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ac2cb5165a94..2521cd49b548 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -297,7 +297,7 @@ struct T2 { }; template -void LaunchAddBiasTranspose( +void AddBiasTransposePacked( const T* input, const T* biases, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -452,7 +452,7 @@ Status FusedScaledDotProductAttention( void* fused_runner = data.fused_runner; ORT_RETURN_IF_NOT(nullptr != fused_runner, "fused_runner cannot be NULL"); - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::QKV_BSN3H, data.token_offset, @@ -477,7 +477,7 @@ Status FusedScaledDotProductAttentionCutlass( const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BSNH, data.token_offset, @@ -564,7 +564,7 @@ Status UnfusedScaledDotProductAttention( T* k = q + elements_q; T* v = k + elements_k; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BNSH, data.token_offset, @@ -657,6 +657,20 @@ Status QkvToContext( return UnfusedScaledDotProductAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const float* input, const float* biases, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const half* input, const half* biases, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); 
+ template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index b4ca0194b08b..e5a4c54f4890 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -502,7 +502,7 @@ struct T2 { }; template -void LaunchTranspose( +void AddBiasTransposePacked( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -566,11 +566,11 @@ Status FusedAttentionTrt( // When packed QKV is used, we can directly pass it to fused runner. Otherwise, we need transpose to BSN3H format. const T* qkv = data.query; if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, + data.token_offset, parameters.token_count, stream); qkv = data.workspace; } @@ -601,11 +601,11 @@ Status FlashAttention( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) @@ -675,11 +675,11 @@ Status FusedAttentionCutlass( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. 
Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } MemoryEfficientAttentionParams p; @@ -746,11 +746,11 @@ Status UnfusedAttention( const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); // Q, K and V pointers when fused attention is not used - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, + data.token_offset, parameters.token_count, stream); T* qkv = data.workspace; T* q = qkv; @@ -848,6 +848,22 @@ Status QkvToContext( return UnfusedAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const half* query, const half* key, const half* value, const half* bias, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const float* query, const float* key, const float* value, const float* bias, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc index 79f0a18ba515..c1459cb9d08d 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc @@ -21,7 +21,7 @@ namespace cuda { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); + onnxruntime::cuda::Conv); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index e126f8bcb3d1..279df73ee3d4 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -3,17 +3,66 @@ #include "core/common/status.h" #include "core/providers/cuda/nn/conv.h" +#include "core/providers/cuda/tensor/slice.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { namespace contrib { namespace cuda { +Status SliceOutUnwantedOutputSection(cudaStream_t stream, + const void* input_data, gsl::span input_dims, + void* output_data, + const gsl::span& output_dims, + 
const gsl::span& starts, + const gsl::span& ends, + const gsl::span& axes, + size_t element_size) { + SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); + + ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); + + // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape + ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); + + return ::onnxruntime::cuda::SliceCuda::Impl(stream, input_data, input_dims, output_data, + compute_metadata, element_size); +} + +static cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, + const ::onnxruntime::cuda::CudnnConvState& s, + cudnnConvolutionFwdAlgo_t algo, + size_t* sz) { + return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); +} + +static size_t GetMaxWorkspaceSize(cudnnHandle_t handle, + const ::onnxruntime::cuda::CudnnConvState& s, + const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { + // TODO: get maximum available size from memory arena + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation + free = static_cast(static_cast(free) * 0.9); + size_t max_ws_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t err; + size_t sz; + err = GetWorkspaceSize(handle, s, algo[i], &sz); + if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; + max_ws_size = sz; + } + return max_ws_size; +} + template -class FusedConv : public onnxruntime::cuda::Conv { +class FusedConv : public onnxruntime::cuda::CudaKernel { + using CudaT = typename ::onnxruntime::cuda::ToCudaType::MappedType; + public: - using Base = onnxruntime::cuda::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { + FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::CudaKernel(info), conv_attrs_(info) { + auto pads_size = conv_attrs_.pads.size(); + ORT_ENFORCE(pads_size % 2 == 0); std::string activation; ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); ORT_THROW_IF_ERROR(MapMode(activation)); @@ -32,66 +81,331 @@ class FusedConv : public onnxruntime::cuda::Conv { } } + Status UpdateState(OpKernelContext* context, bool bias_expected) const { + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + const auto x_dims = x_shape.AsShapeVector(); + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + // set W + const Tensor* W = context->Input(1); + const TensorShape& w_shape = W->Shape(); + auto w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // set B + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } + // set Z + if (context->InputCount() >= 4) { + const Tensor* Z = context->Input(3); + ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), + ::onnxruntime::cuda::CudnnTensor::GetDataType())); + s_.z_data = reinterpret_cast(Z->Data()); + } else { + s_.z_data = nullptr; + } + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape())); 
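GetMaxWorkspaceSize, defined earlier in this file, caps the algorithm-search workspace at 90% of the currently free device memory (the 10% headroom accounts for fragmentation) and keeps the largest per-algorithm requirement that still fits. The same selection rule on plain numbers, as a minimal sketch with no CUDA calls and an illustrative function name:

#include <algorithm>
#include <cstddef>

// Illustrative: pick the largest workspace size that fits within 90% of free memory.
static size_t MaxFittingWorkspace(const size_t* sizes, int n, size_t free_bytes) {
  const size_t budget = static_cast<size_t>(static_cast<double>(free_bytes) * 0.9);  // 10% headroom
  size_t best = 0;
  for (int i = 0; i < n; ++i) {
    if (sizes[i] > 0 && sizes[i] <= budget) best = std::max(best, sizes[i]);
  }
  return best;
}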
+ + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); + + const size_t kernel_rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(kernel_rank * 2, 0); + } + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_rank, 1); + } + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_rank, 1); + } + + TensorShapeVector y_dims; + y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' + + const int64_t N = X->Shape()[0]; + const int64_t M = W->Shape()[0]; + y_dims.insert(y_dims.begin(), {N, M}); + + bool post_slicing_required = false; + TensorShapeVector slice_starts; + slice_starts.reserve(kernel_rank); + + TensorShapeVector slice_ends; + slice_ends.reserve(kernel_rank); + + TensorShapeVector slice_axes; + slice_axes.reserve(kernel_rank); + + constexpr size_t spatial_dim_start = 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); + + TensorShapeVector y_dims_with_adjusted_pads(y_dims); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, + strides, dilations, pads, y_dims, + y_dims_with_adjusted_pads, post_slicing_required, + slice_starts, slice_ends, slice_axes)); + + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); + s_.y_dims = gsl::make_span(y_dims); + s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; + s_.post_slicing_required = post_slicing_required; + s_.slice_starts = slice_starts; + s_.slice_ends = slice_ends; + s_.slice_axes = slice_axes; + + s_.Y = context->Output(0, TensorShape(s_.y_dims)); + if (post_slicing_required) { + // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. + s_.memory_for_cudnn_conv_results = + GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, + context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + // No post slicing needed. Fill the output tensor's buffer directly. + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. + // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. 
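The comment above describes two ways to promote a 1-D convolution input [N, C, D] to the 4-D shape cuDNN expects: [N, C, 1, D] when GetCudnnConv1dPadToNc1d is set (the layout that benefits ORTModule/training), otherwise [N, C, D, 1]. A tiny sketch of just the shape promotion follows; the helper name is hypothetical and the snippet is illustrative only.

#include <cstdint>
#include <vector>

// Illustrative: promote a conv1d shape {N, C, D} to the 4-D layout used for cuDNN.
static std::vector<int64_t> PromoteConv1dShape(const std::vector<int64_t>& ncd, bool pad_to_nc1d) {
  if (pad_to_nc1d) {
    return {ncd[0], ncd[1], 1, ncd[2]};  // [N, C, 1, D]
  }
  return {ncd[0], ncd[1], ncd[2], 1};    // [N, C, D, 1]
}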
+ if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + if (w_dims_changed) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + ::onnxruntime::cuda::CudnnTensor::GetDataType(), UseTF32())); + + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + const auto& b_shape = B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = b_shape[0]; + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + // s_.b_data = reinterpret_cast(B->Data()); + } else if (bias_expected) { + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = w_dims[0]; + auto malloc_size = b_dims[1] * sizeof(CudaT); + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, ::onnxruntime::cuda::CudnnTensor::GetDataType())); + if (s_.b_zero) { + CUDA_CALL_THROW(cudaFree(s_.b_zero)); + s_.b_zero = nullptr; + } + CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size)); + CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); + } + + if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionFwdAlgoPerf_t perf; + int algo_count = 1; + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", + cudnn_conv_algo); + switch (cudnn_conv_algo) { + case 0: { + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), + s_, kAllAlgos, num_algos) + : ::onnxruntime::cuda::AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. 
+ IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + GetCudnnHandle(context), + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.y_tensor, + s_.y_data, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf, + algo_search_workspace.get(), + max_ws_size)); + break; + } + case 1: + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( + GetCudnnHandle(context), + s_.x_tensor, + s_.w_desc, + s_.conv_desc, + s_.y_tensor, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf)); + break; + + default: + perf.algo = kDefaultConvAlgo; + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); + + if constexpr (std::is_same::value) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; + } else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + } + s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); + } + const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (s_.post_slicing_required) { + s_.memory_for_cudnn_conv_results = GetScratchBuffer( + TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + } + return Status::OK(); + } + Status ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); + std::lock_guard lock(s_.mutex); auto cudnnHandle = this->GetCudnnHandle(context); - ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); - if (Base::s_.Y->Shape().Size() == 0) { + ORT_RETURN_IF_ERROR(UpdateState(context, true)); + if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - typedef typename onnxruntime::cuda::ToCudaType::MappedType CudaT; + bool has_z = nullptr != s_.z_data; + bool has_b = nullptr != s_.b_data; const auto alpha = onnxruntime::cuda::Consts::One; const auto beta = onnxruntime::cuda::Consts::Zero; - IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); auto cudnn_status = cudnnConvolutionBiasActivationForward(cudnnHandle, &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.algo, + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.algo, workspace.get(), - Base::s_.workspace_bytes, + s_.workspace_bytes, has_z ? &alpha : &beta, - has_z ? Base::s_.z_tensor : Base::s_.y_tensor, - has_z ? Base::s_.z_data : Base::s_.y_data, - Base::s_.b_tensor, - has_b ? Base::s_.b_data : Base::s_.b_zero, + has_z ? s_.z_tensor : s_.y_tensor, + has_z ? s_.z_data : s_.y_data, + s_.b_tensor, + has_b ? 
s_.b_data : s_.b_zero, activation_desc_, - Base::s_.y_tensor, - Base::s_.y_data); + s_.y_tensor, + s_.y_data); if (CUDNN_STATUS_SUCCESS != cudnn_status) { CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnnHandle, &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.algo, + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.algo, workspace.get(), - Base::s_.workspace_bytes, + s_.workspace_bytes, &beta, - Base::s_.y_tensor, - Base::s_.y_data)); + s_.y_tensor, + s_.y_data)); if (has_b) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, Base::s_.b_tensor, Base::s_.b_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data)); + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); } if (has_z) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, Base::s_.z_tensor, Base::s_.z_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data)); + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnnHandle, &alpha, s_.z_tensor, s_.z_data, + &alpha, s_.y_tensor, s_.y_data)); } - CUDNN_RETURN_IF_ERROR(cudnnActivationForward(cudnnHandle, activation_desc_, &alpha, Base::s_.y_tensor, - Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data)); + CUDNN_RETURN_IF_ERROR(cudnnActivationForward(cudnnHandle, activation_desc_, &alpha, s_.y_tensor, + s_.y_data, &beta, s_.y_tensor, s_.y_data)); } - if (Base::s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection( - this->Stream(context), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(), - Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size)); + if (s_.post_slicing_required) { + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection( + this->Stream(context), s_.y_data, s_.y_dims_with_adjusted_pads, s_.Y->MutableDataRaw(), + s_.y_dims.GetDims(), s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size)); } return Status::OK(); } @@ -107,6 +421,25 @@ class FusedConv : public onnxruntime::cuda::Conv { } return Status::OK(); } + + inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + return GetScratchBuffer(s_.workspace_bytes, stream); + } + + ConvAttributes conv_attrs_; + mutable ::onnxruntime::cuda::CudnnConvState s_; + constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + constexpr static cudnnConvolutionFwdAlgo_t kAllAlgos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + cudnnActivationMode_t activation_mode_; cudnnActivationDescriptor_t activation_desc_ = nullptr; }; @@ -122,4 +455,4 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( } // namespace cuda } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 168c69c69f00..b62e566d43f8 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -190,7 +190,8 @@ Status 
QAttention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + true); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -208,6 +209,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index e10c2ec63fd5..6d52ff728279 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -13,6 +13,9 @@ namespace cuda { #if DUMP_TENSOR_LEVEL > 0 +// Environment variable to enable/disable GPU Tensor dumping +constexpr const char* kEnableGpuTensorDumper = "ORT_ENABLE_GPU_DUMP"; + // Total number of elements which trigger snippet rather than full dump (default 200). Value 0 disables snippet. constexpr const char* kTensorSnippetThreshold = "ORT_TENSOR_SNIPPET_THRESHOLD"; @@ -202,6 +205,10 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { + is_enabled_ = ParseEnvironmentVariableWithDefault(kEnableGpuTensorDumper, 1) != 0; +} + void CudaTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } @@ -329,6 +336,8 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { +} void CudaTensorConsoleDumper::Print(const std::string&) const { } diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 6ad0ad9a67b7..4f41161cd4a3 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -13,7 +13,7 @@ namespace cuda { class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { public: - CudaTensorConsoleDumper() = default; + CudaTensorConsoleDumper(); virtual ~CudaTensorConsoleDumper() {} void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index b0ed3ff82226..b94971ffd44d 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -119,7 +119,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; return Status::OK(); } @@ -128,7 +128,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; return Status::OK(); } @@ -136,7 +136,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == 
Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; return Status::OK(); } @@ -146,7 +146,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); } @@ -154,7 +154,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 349df045becf..d593bc001282 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -132,12 +132,6 @@ class CompatRocblasMathModeSetter { } }; -enum AttentionType { - kAttention, - kMultiHeadAttention, - kDecoderMaskedMultiHeadAttention, -}; - enum AttentionMode { // Q,K,V,PastK,PastV,PresentK,PresentV QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 09e7d61b71db..5997daaca6e8 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -122,9 +122,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, num_heads_, - mask_filter_value_, scale_, false, /*is_unidirectional_*/ - past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ + past_present_share_buffer_, + attn_type_, + device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index e2c06fbdfa62..850cb167a3ec 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "node_unit.h" +#include #include "core/graph/graph_viewer.h" namespace onnxruntime { @@ -272,6 +273,20 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, Node::EdgeSet output_edges) + : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), + target_node_(target_node), + q_nodes_(q_nodes.begin(), q_nodes.end()), + type_(unit_type), + inputs_(inputs.begin(), inputs.end()), + outputs_(outputs.begin(), outputs.end()), + input_edge_count_(input_edge_count), + output_edges_(std::move(output_edges)) { +} + const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } const std::string& NodeUnit::OpType() const noexcept { return 
target_node_.OpType(); } const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index e84e62479162..50bd423d2f54 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,6 +68,10 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); + NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a88f36f63639..ddb0c3356e54 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1486,7 +1486,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string #include #include #include +#include #include #include "core/common/flatbuffers.h" @@ -303,6 +304,10 @@ class SessionState { const InlinedHashSet* GetToBeExecutedRange(gsl::span fetch_mlvalue_idxs) const; #endif + std::unordered_map>* GetMutableBufferedTensors() { + return &name_to_buffered_tensor_; + } + Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, bool remove_initializers = true, @@ -562,6 +567,12 @@ class SessionState { // flag to indicate whether current session using any EP that create device stream dynamically. bool has_device_stream_enabled_ep_ = false; #endif + + // Holds the tensors which provide the memory buffers for TensorProtos. + // Use case: an optimizer transforms a TensorProto into a new TensorProto whose memory buffer is + // allocated by the CPU allocator instead of by protobuf's arena. Arena-style memory allocators do not + // fully release an instance's memory, which may result in large memory consumption; that is a tradeoff for speed. + std::unordered_map> name_to_buffered_tensor_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 059de8e3c8c4..b13b0cd27496 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include @@ -61,17 +63,23 @@ struct ExtDataValueDeleter { // given a tensor proto with external data return an OrtValue with a tensor for // that data; the pointers for the tensor data and the tensor itself are owned -// by the OrtValue's deleter +// by the OrtValue's deleter. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released.
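The ownership handoff for buffered tensors works the way the comments describe: the session keeps pre-allocated tensors in a name-keyed map of unique_ptr, and deserialization releases the matching entry so the OrtValue's deleter can free it later. A simplified sketch of that handoff using standard containers in place of the ORT types (all names here are illustrative):

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for onnxruntime::Tensor in this sketch.
struct Buffer { std::vector<char> bytes; };

using BufferedMap = std::unordered_map<std::string, std::unique_ptr<Buffer>>;

// Mirrors the lookup in SaveInitializedTensors: if a buffered entry exists for
// this initializer name, release it from the map and hand the raw pointer to
// the deserializer, whose deleter callback becomes responsible for deleting it.
static Buffer* TakeBufferedTensor(BufferedMap& buffered, const std::string& name) {
  auto it = buffered.find(name);
  if (it == buffered.end()) return nullptr;
  Buffer* raw = it->second.release();
  buffered.erase(it);
  return raw;
}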
static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor, OrtCallback& ext_data_deleter) { + Tensor& tensor, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); void* ext_data_buf = nullptr; SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, - ext_data_buf, ext_data_len, ext_data_deleter)); + ext_data_buf, ext_data_len, ext_data_deleter, + buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -83,16 +91,24 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, return common::Status::OK(); } +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, - bool use_device_allocator_for_initializers = false) { + bool use_device_allocator_for_initializers = false, + Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } + ORT_RETURN_IF(buffered_tensor && !utils::HasExternalData(tensor_proto), + "With buffered tensor, tensor proto must use external location and point to buffered tensor"); + // Get shape and type of the tensor, and allocate the empty tensor TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); @@ -123,7 +139,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. 
If we called // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, ext_data_deleter)); + ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, + ext_data_deleter, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; @@ -154,7 +171,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st std::optional scoped_ort_callback_invoker; if (utils::HasExternalData(tensor_proto)) { ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter)); + ext_data_deleter, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); } else { ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); @@ -187,7 +204,8 @@ common::Status SaveInitializedTensors( const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func) { + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -307,9 +325,16 @@ common::Status SaveInitializedTensors( bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; + Tensor* p_tensor = nullptr; + if (auto iter = buffered_tensors.find(name); + iter != buffered_tensors.end()) { + p_tensor = iter->second.release(); + buffered_tensors.erase(iter); + } + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, - use_device_allocator_for_initializers); + use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; oss << "Deserialize tensor " << name << " failed." 
<< st.ErrorMessage(); diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index af44c35fbb7f..499222b6ec61 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -3,6 +3,9 @@ #pragma once #include +#include +#include +#include #include "core/common/const_pointer_container.h" #include "core/framework/allocator.h" @@ -44,7 +47,8 @@ common::Status SaveInitializedTensors( const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func); + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors); common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, SessionState& session_state, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 4ecd61962d79..cbd53298ab2a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -987,7 +987,8 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, - SafeInt& ext_data_len, OrtCallback& ext_data_deleter) { + SafeInt& ext_data_len, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -1003,7 +1004,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo // the value in location is the memory address of the data ext_data_buf = reinterpret_cast(file_offset); ext_data_len = raw_data_safe_len; - ext_data_deleter = OrtCallback{nullptr, nullptr}; + if (buffered_tensor) { + ext_data_deleter = OrtCallback{[](void* p) noexcept { delete reinterpret_cast(p); }, + reinterpret_cast(buffered_tensor)}; + } else { + ext_data_deleter = OrtCallback{nullptr, nullptr}; + } } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, @@ -1241,7 +1247,9 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto return CApiElementTypeFromProtoType(tensor_proto.data_type()); } -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name) { +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer) { // Set name, dimensions, type, and data of the TensorProto. ONNX_NAMESPACE::TensorProto tensor_proto; @@ -1259,6 +1267,28 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: for (; f < end; ++f) { *mutable_string_data->Add() = *f; } + } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { + // The logic aligns with + // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302 + const auto* raw_data = tensor.DataRaw(); + ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); + static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. 
+ // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the + // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. + auto offset = narrow(reinterpret_cast(raw_data)); + + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(tensor.SizeInBytes())); } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index e5197adcb94e..2af1f080be7e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -114,14 +114,22 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); -/** Creates a TensorProto from a Tensor. - @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. - @param[in] tensor_proto_name the name of the TensorProto. - @return the TensorProto. - - Note: Method currently requires that data is in little-endian format. +/** + * @brief Creates a TensorProto from a Tensor. + * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. + * @param[in] tensor_proto_name the name of the TensorProto. + * @param[in] use_tensor_buffer the tensor proto is set to use external location, with + * 'location' set to onnxruntime::utils::kTensorProtoMemoryAddressTag + * 'offset' set to tensor's memory location, and 'length' set to tensor's + * memory size. The caller is responsible to maintain the lifetime of + * the allocated memory buffer. Use with caution. + * @return the TensorProto. + * + * Note: Method currently requires that data is in little-endian format. */ -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name); +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer = false); ONNXTensorElementDataType CApiElementTypeFromProtoType(int type); ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto); @@ -141,10 +149,15 @@ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_A // Given a tensor proto with external data obtain a pointer to the data and its length. // The ext_data_deleter argument is updated with a callback that owns/releases the data. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. 
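// A minimal sketch of the round trip enabled by the two changes above, assuming a Tensor `t`
// whose buffer outlives the proto (variable names are illustrative; the in-memory path is only
// taken when t.SizeInBytes() > 127):
//
//   ONNX_NAMESPACE::TensorProto p = utils::TensorToTensorProto(t, "w", /*use_tensor_buffer=*/true);
//   // p now carries external_data with location == kTensorProtoMemoryAddressTag,
//   // offset == the integer value of t.DataRaw(), and length == t.SizeInBytes().
//
//   void* buf = nullptr;
//   SafeInt<size_t> len = 0;
//   OrtCallback deleter{nullptr, nullptr};
//   ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), std::filesystem::path{}, p,
//                                                        buf, len, deleter, /*buffered_tensor=*/nullptr));
//   // buf aliases t.DataRaw(); no copy is made. If a buffered_tensor owning the buffer had been
//   // passed instead, `deleter` would delete it and must be invoked exactly once by the consumer.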
common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, - OrtCallback& ext_data_deleter); + OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr); // Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer // If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 442a0db933d6..e950d68947b9 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3254,27 +3254,6 @@ Status Graph::PerformTypeAndShapeInferencing(const ResolveOptions& options) { return Status::OK(); } -void Graph::FindAllSubgraphs(std::vector& subgraphs) { - for (auto& node : Nodes()) { - for (auto& subgraph : node.MutableSubgraphs()) { - subgraphs.push_back(subgraph.get()); - subgraph->FindAllSubgraphs(subgraphs); - } - } -} - -Status Graph::ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func) { - auto status = func(*this); - ORT_RETURN_IF_ERROR(status); - - for (auto& subgraph : subgraphs) { - status = func(*subgraph); - ORT_RETURN_IF_ERROR(status); - } - - return status; -} - Status Graph::Resolve(const ResolveOptions& options) { if (parent_graph_) { // Resolve must start at the top level graph in-order to handle outer scope @@ -3387,6 +3366,39 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); } } + +void Graph::FindAllSubgraphs(std::vector& subgraphs) { + for (auto& node : Nodes()) { + for (auto& subgraph : node.MutableSubgraphs()) { + subgraphs.push_back(subgraph.get()); + subgraph->FindAllSubgraphs(subgraphs); + } + } +} + +Status Graph::ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func) { + auto status = func(*this); + ORT_RETURN_IF_ERROR(status); + + for (auto& subgraph : subgraphs) { + status = func(*subgraph); + ORT_RETURN_IF_ERROR(status); + } + + return status; +} + +Status Graph::RemovedUnusedInitializersOrtFormat() { + std::vector all_subgraphs; + FindAllSubgraphs(all_subgraphs); + auto cleanup_func = [](Graph& graph) { + graph.CleanUnusedInitializersAndNodeArgs(nullptr); + return Status::OK(); + }; + + auto result = ForThisAndAllSubgraphs(all_subgraphs, cleanup_func); + return result; +} #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const std::string& Graph::Name() const noexcept { @@ -4122,6 +4134,9 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const } } +#endif // !defined(ORT_MINIMAL_BUILD) + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set* initializer_names_to_preserve) { // Node Args being used std::unordered_set used_args; @@ -4253,8 +4268,7 @@ void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set* PostProcessor = nullptr; @@ -159,14 +161,29 @@ MlasSQNBitGemmPackQuantBDataSize( /** * @brief Packs the quantized B data in a format that the kernel expects. 
* - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - * @param[in] QuantBData quantized B data - * @param[out] PackedQuantBData packed quantized B data - * @param[in] ThreadPool optional thread pool to use + * If the function is called without QuantBScale and QuantBZeroPoint, + * it just packs QuantBData into PackedQuantBDataAndOrBlkSum. + * + * If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint + * additional BlkSum (Scale * zeropoint) is computed and stored at the second part of PackedQuantBDataAndOrBlkSum. + * + * Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData, + * QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData, + * and then QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale without QuantBZeroPoint, + * BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum. + * If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint. + * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum + * @param[in] QuantBScale quantized B scale + * @param[in] has_zp_input whether QuantBZeroPoint is provided + * @param[in] QuantBZeroPoint quantized B zero point + * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL MlasSQNBitGemmPackQuantBData( @@ -176,6 +193,9 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, - MLAS_THREADPOOL* ThreadPool = nullptr + void* PackedQuantBDataAndOrBlkSum, + const void* QuantBScale, + bool has_zp_input, + const void* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/dwconv.cpp b/onnxruntime/core/mlas/lib/dwconv.cpp index 15511d2d8cea..d48d9cbb1750 100644 --- a/onnxruntime/core/mlas/lib/dwconv.cpp +++ b/onnxruntime/core/mlas/lib/dwconv.cpp @@ -14,7 +14,6 @@ Module Name: --*/ - #include "fp16_common.h" #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -24,19 +23,20 @@ void MlasConvDepthwiseKernel( const _mlas_fp16_* const* Input, const _mlas_fp16_* Filter, + const _mlas_fp16_* Bias, _mlas_fp16_* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { while (OutputCount > 0) { size_t ChannelOffset = 0; size_t c = Channels; while (c >= 8) { - MLAS_FLOAT16X8 Accumulator = MlasZeroFloat16x8(); + MLAS_FLOAT16X8 Accumulator = Bias == nullptr ? MlasZeroFloat16x8() : MlasLoadFloat16x8(&Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -54,7 +54,7 @@ MlasConvDepthwiseKernel( } if (c >= 4) { - MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + MLAS_FLOAT16X4 Accumulator = Bias == nullptr ? 
MlasZeroFloat16x4() : MlasLoadFloat16x4(&Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -72,7 +72,8 @@ MlasConvDepthwiseKernel( } if (c > 0) { - MLAS_FLOAT16X4 Accumulator = MlasZeroFloat16x4(); + MLAS_FLOAT16X4 Accumulator = + Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadPartialFloat16x4(&Bias[ChannelOffset], c); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -86,8 +87,7 @@ MlasConvDepthwiseKernel( Output += c; } if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, - Channels); + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); } Input += KernelSize; OutputCount -= 1; @@ -101,16 +101,17 @@ void MlasConvDepthwiseKernel( const _mlas_fp16_* const* Input, const _mlas_fp16_* Filter, + const _mlas_fp16_* Bias, _mlas_fp16_* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { while (OutputCount > 0) { for (size_t ChannelOffset = 0; ChannelOffset < Channels; ChannelOffset++) { - float Accumulator = 0.0f; + float Accumulator = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[ChannelOffset]); size_t ChannelKernelOffset = ChannelOffset; for (size_t k = 0; k < KernelSize; k++) { @@ -120,35 +121,36 @@ MlasConvDepthwiseKernel( *Output++ = MLAS_Float2Half(Accumulator); } if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, - Channels); + PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); } Input += KernelSize; OutputCount -= 1; } } -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED - +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED void MLASCALL MlasConvDepthwise( const MLAS_FP16* const* Input, const MLAS_FP16* Filter, + const MLAS_FP16* Bias, MLAS_FP16* Output, size_t Channels, size_t OutputCount, size_t KernelSize, MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ) +) { MlasConvDepthwiseKernel( reinterpret_cast(Input), reinterpret_cast(Filter), + reinterpret_cast(Bias), reinterpret_cast<_mlas_fp16_*>(Output), Channels, OutputCount, KernelSize, - PostProc); + PostProc + ); } diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 1fcab870af64..30b66cdb2ea7 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -64,6 +64,23 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) +{ + MLAS_FLOAT16X4 Vector = MlasZeroFloat16x4(); + if ((len & 1) != 0) { + Vector = vreinterpret_f16_u16(vld1_lane_u16(Buffer + (len - 1), vreinterpret_u16_f16(Vector), 0)); + } + if ((len & 2) != 0) { + Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 0)); + Vector = vreinterpret_f16_f32( + vld1_lane_f32(reinterpret_cast(Buffer), vreinterpret_f32_f16(Vector), 0) + ); + } + return Vector; +} + MLAS_FORCEINLINE void MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e..4239e2ecaeb6 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const 
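// For reference, the depthwise kernels above now seed each channel's accumulator with the bias
// instead of leaving bias handling to a post-processor. A scalar sketch of one output element,
// with fp16 values widened to float (index names are illustrative):
//
//   float acc = (Bias != nullptr) ? MLAS_Half2Float(Bias[c]) : 0.0f;
//   for (size_t k = 0; k < KernelSize; ++k) {
//       acc += MLAS_Half2Float(Input[k][c]) * MLAS_Half2Float(Filter[k * Channels + c]);
//   }
//   Output[c] = MLAS_Float2Half(acc);
//
// MlasLoadPartialFloat16x4(Buffer, len) supports the channel tail (len in 1..3): it fills only the
// low `len` lanes from Buffer and leaves the remaining lanes zero, so loading a partial bias
// vector never reads past the end of the Bias array.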
MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; + extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 859b7c2f560a..ed437f20f7c2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -409,6 +409,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 81789386a320..a45494ef2e04 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,11 +16,10 @@ Module Name: --*/ #include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" #include -#include "sqnbitgemm_q8_block.h" - namespace { @@ -80,9 +79,10 @@ MlasIsSQNBitGemmAvailable( return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && - Dispatch->QuantizeARow_CompInt8 != nullptr; + case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + return + (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || + (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } default: { return false; @@ -197,6 +197,21 @@ MlasSQNBitGemmPackQuantBDataSize( return 0; } +struct PerGemmQuantAWorkspace { + PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) + : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + QuantData = (std::byte*)PerGemmWorkspace; + QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); + BlockSum = QuantScale + M * BlockCountK; + } + std::byte* QuantData; // NxBlockCountKxBlkLen + float* QuantScale; // NxBlockCountK + float* BlockSum; // NxBlockCountK + void* PerGemmWorkspace_; // memory for above data + size_t M_, BlockCountK_, BlkLen_; +}; + void MLASCALL MlasSQNBitGemmPackQuantBData( size_t N, @@ -205,7 +220,10 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, + void* PackedQuantBDataAndOrBlkSumWorkspace, + const void* QuantBScale, + bool has_zp_input, + const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ) { @@ -214,17 +232,37 @@ MlasSQNBitGemmPackQuantBData( return; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBData), - ThreadPool - ); - return; + if (BlkBitWidth == 4) { + if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + has_zp_input, + static_cast(QuantBZeroPoint), + packed_quant_b, 
+ ThreadPool + ); + } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. + //assert(QuantBScale == nullptr); + //assert(QuantBZeroPoint == nullptr); + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; + } } } @@ -293,7 +331,7 @@ SQ4BitGemm_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -373,7 +411,6 @@ SQ4BitGemm_CompFp32( if (bias) { AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); } - if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, @@ -383,7 +420,6 @@ SQ4BitGemm_CompFp32( c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; } } @@ -402,16 +438,33 @@ SQ4BitGemm_CompInt8( ) { #ifdef MLAS_TARGET_AMD64_IX86 - if (RangeCountM != 1) { - // perf experiment shows fp32 is faster than int8 in M > 1 cases. - // route to fp32 compute before int8 compute is improved. - SQ4BitGemm_CompFp32( - BlkLen, - K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN - ); - return; - } -#endif + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 4 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; +#else constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); @@ -423,7 +476,7 @@ SQ4BitGemm_CompInt8( const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -433,6 +486,7 @@ SQ4BitGemm_CompInt8( float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#endif size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { @@ -446,25 +500,57 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc + ); + } + + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; + } + } +#ifdef MLAS_TARGET_AMD64_IX86 + else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum ); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc ); } - - c_blk += RowsHandled * ldc; - a_row += RowsHandled * lda; - - RowsRemaining -= RowsHandled; } +#endif } } @@ -496,23 +582,44 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; + // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
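// The blk-sum CompInt8 path above relies on two sets of per-block sums. The per-GEMM A workspace
// is carved up by PerGemmQuantAWorkspace (defined earlier in this file), with
// BlockCountK = ceil(K / BlkLen):
//
//   QuantData  : M * BlockCountK * BlkLen  bytes   (quantized int8 A values)
//   QuantScale : M * BlockCountK           floats  (per-block scales)
//   BlockSum   : M * BlockCountK           floats  (scale_a * sum of the block's int8 values)
//
// i.e. roughly M * BlockCountK * (BlkLen + 2 * sizeof(float)) bytes in total, whereas the legacy
// QuantizeARow path stores the scale and data interleaved per block (Q8BlkSize(BlkLen) bytes each).
// The reason the sums are worth carrying: with symmetric A quantization and a B zero point zp_b,
// each block contributes
//
//   sum_i (scale_a * a_i) * (scale_b * (b_i - zp_b))
//     = scale_a * scale_b * sum_i(a_i * b_i)  -  (scale_b * zp_b) * (scale_a * sum_i(a_i))
//
// so the int8 kernels only need the raw dot products, and the second term is applied for all
// blocks at once by the GemmFloatKernel pass over ABlockSum and QuantBBlkSum in the AVX2
// kernels further below.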
+ if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } } struct Operations { @@ -530,7 +637,6 @@ constexpr auto OperationMap = []() { return ops; }(); - } // namespace void MLASCALL @@ -572,12 +678,23 @@ MlasSQNBitGemmBatch( const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + } } return; } @@ -627,9 +744,6 @@ MlasSQNBitGemmBatch( const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; const auto* Data = &DataParams[gemm_i]; - void* PerGemmWorkspace = reinterpret_cast( - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride - ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -640,6 +754,18 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + 
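// In this branch the data params are completed from the caller-supplied QuantBDataWorkspace:
// PackedQuantBDataStruct lays the workspace out as the 32-byte-aligned packed 4-bit data, then
// the block sums aligned to 16 floats (ceil(N / 16) * 16 * BlockCountK floats, matching what
// GemmFloatKernel expects), then the repacked scales. The const_casts only fill in these derived
// pointers (PackedQuantBData, QuantBBlkSum, QuantBScale); the buffers the caller provided are
// not modified here.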
PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 8321dcc217e9..2da336ca2f0e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -25,12 +25,50 @@ Module Name: #include "mlas_qnbit.h" #include "mlasi.h" +constexpr MLAS_FORCEINLINE size_t +MlasQNBitQuantBBlkSumAlignment() +{ + // 16 floats. this alignment is required by GemmFloatKernel + return 16 * sizeof(float); +} + constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { return BlkLen * BlkBitWidth / 8; } +MLAS_FORCEINLINE void* +MlasAlignAddress(void* addr, const size_t alignment) +{ + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); + addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); + return addr; +} + +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + constexpr size_t BlkBitWidth = 4; + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + } + std::byte* PackedQuantBData; + float* PackedQuantBScale; + float* QuantBBlkSum; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + template constexpr MLAS_FORCEINLINE size_t MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) @@ -74,6 +112,21 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + // // Workspace size calculation function prototypes. // @@ -181,6 +234,45 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { // CompInt8 kernel function prototypes. // + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. 
+ * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -235,4 +327,14 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { ); QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; + + typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledGroupSum // scale_k * Sum_blklen(a_i) + ); + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 0922f5ef646b..55d86bb9cc18 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -22,6 +22,12 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen64.h" + +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" MLAS_FORCEINLINE __m256 @@ -338,38 +344,92 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( } } +template +MLAS_FORCEINLINE +void +SQ4BitGemmKernel_CompInt8_avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } +} + +template MLAS_FORCEINLINE void 
SQ4BitGemmM1Kernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, size_t CountN, - size_t CountK, + size_t /*CountK*/, size_t BlockStrideQuantB, const float* Bias ) { - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; + if (QuantBZeroPoint) { if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -379,36 +439,25 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, - CountK, BlockStrideQuantB, Bias ); } } else { - constexpr bool HasZeroPoint = false; if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -418,15 +467,15 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, - CountK, BlockStrideQuantB, Bias ); @@ -434,10 +483,12 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } } +MLAS_FORCEINLINE size_t -SQ4BitGemmKernel_CompInt8_avx2( - size_t BlkLen, +SQ4BitGemmKernel_BlkSum_CompInt8_avx2( + const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -446,30 +497,101 @@ SQ4BitGemmKernel_CompInt8_avx2( size_t CountN, size_t CountK, size_t BlockCountK, + const float* Bias, size_t ldc, - const float* Bias + const float* ABlockSum, + const float* QuantBBlkSum ) { - MLAS_UNREFERENCED_PARAMETER(ldc); + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); - if (CountM == 0) { - return 0; + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; } + return CountM; +} - SQ4BitGemmM1Kernel_CompInt8_avx2( +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t 
ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( BlkLen, QuantA, + QuantAScale, QuantBData, QuantBScale, - QuantBZeroPoint, C, + CountM, CountN, CountK, BlockCountK, - Bias + Bias, + ldc ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); - return 1; + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; } template @@ -1053,30 +1175,23 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( } } -MLAS_FORCEINLINE __m128i -convert_2_ps_to_epi8(__m256 v0, __m256 v1) -{ - __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); - __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); - - __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); - __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); - - return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); -} - void MLASCALL QuantizeARow_CompInt8_avx2( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_srli_epi16( + _mm256_cmpeq_epi16(_mm256_castps_si256(signBit), _mm256_castps_si256(signBit)), 15); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -1097,13 +1212,14 @@ QuantizeARow_CompInt8_avx2( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const int klen = std::min(16, (int)(step - kk)); @@ -1122,16 +1238,50 @@ QuantizeARow_CompInt8_avx2( v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); } - __m128i i_8 = convert_2_ps_to_epi8(v0, v1); - _mm_storeu_si128(dst++, i_8); + __m128i i_16_epi8 = convert_2_ps_to_epi8(v0, v1); + _mm_storeu_si128(dst++, i_16_epi8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } +static void +SQ4BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == CompInt8) { + SubBlkLen = 64; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); +} + // // Kernel dispatch structure definition. // @@ -1140,6 +1290,26 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + + return d; +}(); + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -1147,8 +1317,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h new file mode 100644 index 000000000000..80d67806ea6e --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -0,0 +1,727 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE __m256 +load_and_broadcast_4_scale_2(const float* scale) +{ + // 3 2 1 0 3 2 1 0 (7) + __m256 scale_2_4_ps = _mm256_broadcast_ps((__m128 const*)scale); + + // 2 1 0 0 2 1 0 0 (1) + __m256 scale_2_4_ps_shifted = _mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_castps_si256(scale_2_4_ps), 4) + ); + + // 3 2 1 0 2 1 0 0: (3) cross lane + __m256 scale_2_4_ps_permutted = _mm256_permute2f128_ps( + scale_2_4_ps_shifted, scale_2_4_ps, 0b00110000 + ); + + // in accumulate_r1_4blk_dot and accumulate_r2_4blk_dot + // _mm256_hadd_epi16 inter leaved dot sum, resulting: + // a31b31|a30b30|a11b11|a10b10|a21b21|a20b20|a01b01|a00b00 + // therefore we need weight to be: + // 3 3 1 1 2 2 0 0 (1) + return _mm256_permute_ps(scale_2_4_ps_permutted, 0b11110101); +} + +MLAS_FORCEINLINE +__m256i +load_16_epi8_as_epi16(const std::byte* ablob) +{ + const __m128i av_epi8 = _mm_lddqu_si128(reinterpret_cast(ablob)); + __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); + return av_epi16; +} + +MLAS_FORCEINLINE void +accumulate_r1_4blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a, const float* scale_b, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av0_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av1_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a_4_ps = load_and_broadcast_4_scale_2(scale_a); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a_4_ps, scale_b_4_ps); + acc = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc); +} + +MLAS_FORCEINLINE void +accumulate_r2_4blk_dot( + const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a0, const float* scale_a1, const float* scale_b, + __m256& acc0, __m256& acc1 +) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a0_4_ps = load_and_broadcast_4_scale_2(scale_a0); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a0_4_ps, scale_b_4_ps); + acc0 = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc0); + + const __m256i 
dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_inter_leaved_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_inter_leaved_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16_); + const __m256 sum_inter_leaved_ps_ = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32_); + + __m256 scale_a1_4_ps = load_and_broadcast_4_scale_2(scale_a1); + scale_8_ps = _mm256_mul_ps(scale_a1_4_ps, scale_b_4_ps); + acc1 = _mm256_fmadd_ps(sum_inter_leaved_ps_, scale_8_ps, acc1); +} + +static MLAS_FORCEINLINE __m256i +load_4b_packed_1blk_blklen16(const std::byte* QuantBDataPtr) +{ + // | 0 8 |...| 7 15 | + const __m128i bv_packed_64 = _mm_loadl_epi64(reinterpret_cast(QuantBDataPtr)); + const __m128i low_mask = _mm_set1_epi8(0xF); + const __m128i lower_8_epu8 = _mm_and_si128(bv_packed_64, low_mask); // 0~7 + const __m128i upper_8_epu8 = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bv_packed_64, 4), low_mask), 8); // 8~15 + const __m256i bv_16_epu16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper_8_epu8, lower_8_epu8)); // 0~15 + return bv_16_epu16; +} + +static MLAS_FORCEINLINE void +load_4b_packed_4blk_blklen16(const std::byte* QuantBDataPtr, __m256i& bv0_32_epi8, __m256i& bv1_32_epi8) +{ + // | 0 8 |...| 7 15 | 16 24 |...| 23 31 ||| 32 40 |...| 39 47 | 48 56 |...| 55 63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + // 0~7, 16~22, 32~39, 48~55 + __m256i bv0_32_epi8_ = _mm256_and_si256(bv_packed, low_mask); + // 8~15, 24~31, 40~47, 56~63: (1) + __m256i bv1_32_epi8_ = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8_), 4); + // 0~7, 32~39, 16~22, 48~55 <- cross lane (3) + bv0_32_epi8_ = _mm256_permute4x64_epi64(bv0_32_epi8_, 0b11011000); + // 40~47, 8~15, 56~63, 24~31 <- cross lane (3) + bv1_32_epi8_ = _mm256_permute4x64_epi64(bv1_32_epi8_, 0b01110010); + + // 0~7, 8~15, 16~22, 24~31: (1) + bv0_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b11001100); + + // 40~47, 32~39, 56~63, 48~55: (1) + bv1_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b00110011); + + // 32~39, 40~47, 48~55, 56~63: (1) + bv1_32_epi8 = _mm256_shuffle_epi32(bv1_32_epi8, 0b01001110); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r2_4blk_dot(av00_32_epi8, av01_32_epi8, av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, + scale_a0, scale_a1, scale_b, acc0, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk4_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r1_4blk_dot(av0_32_epi8, av1_32_epi8, bv0_32_epi8, bv1_32_epi8, scale_a, scale_b, acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk1_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + 
const float& combined_scale0, + const float& combined_scale1, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av0_32_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale0), prod_8_ps, acc0); + + prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av1_32_epi8); + prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale1), prod_8_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk1_avx2( + const __m256i& av_16_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + __m256& acc +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av_16_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale), prod_8_ps, acc); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + 
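// The accumulate_blklen16_* helpers used in this loop reduce four 16-value blocks per call:
// _mm256_maddubs_epi16 multiplies the unpacked 4-bit B values (unsigned, 0..15) with the int8 A
// values and sums adjacent pairs into int16; _mm256_hadd_epi16 interleaves the partial sums of
// two blocks; _mm256_madd_epi16 against a vector of ones widens them to int32 partial sums (two
// lanes per block); and the final _mm256_fmadd_ps applies the per-block scale_a * scale_b
// factors, which load_and_broadcast_4_scale_2 has arranged to match that interleaved order.
// Zero-point handling is not done here; it is folded in afterwards by the block-sum pass in
// SQ4BitGemmKernel_BlkSum_CompInt8_avx2.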
accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_blklen16_r2c1blk4_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; 
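+            // (editor note, not part of the original change) This kernel covers the CountN % 4
+            // leftover columns for row pairs: one B column at a time, two A rows per pass. The
+            // eight float lanes of acc0/acc1 hold partial sums; hsum_float_8 above reduces them
+            // to the single C[m][n] and C[m+1][n] values, and the bias (one float per column) is
+            // added once to each row's result.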
+ + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc[3]); + } + 
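+                // (editor note, not part of the original change) In this tail loop each leftover
+                // 16-value K-block is handled on its own: load_16_epi8_as_epi16 widens 16 int8 A
+                // values to 16-bit lanes so a single _mm256_madd_epi16 against the 16 unpacked B
+                // nibbles yields eight 32-bit partial dot products, and
+                // accumulate_blklen16_r1c1blk1_avx2 folds in the per-block combined scale
+                // (scale_a * scale_b) with one FMA.
+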
QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const __m256i av_16_epi16 = load_16_epi8_as_epi16(QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_16_epi16, QuantBDataPtr, scale_00, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr,
+            ldc);
+    }
+
+    return CountM;
+}
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
new file mode 100644
index 000000000000..af6f52090adc
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h
@@ -0,0 +1,1049 @@
+#pragma once
+#include <algorithm>
+#include <cassert>
+#include <utility>
+
+#include "sqnbitgemm.h"
+#include "sqnbitgemm_kernel_avx_common.h"
+
+
+MLAS_FORCEINLINE void
+accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8,
+    const float& combined_scale, const __m256i& one_16_epi16, __m256& acc)
+{
+    const __m256i dot_16_epi16 = _mm256_maddubs_epi16(
+        bv_32_epi8, av_32_epi8
+    );
+    const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16);
+    const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
+    acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
+}
+
+#if !defined(__GNUC__) || (__GNUC__ > 10)
+MLAS_FORCEINLINE void
+accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc)
+{
+    __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8);
+    const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
+    acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
+}
+#endif
+
+template <bool vnni>
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk2_avx2(
+    const __m256i& av00_32_epi8,
+    const __m256i& av01_32_epi8,
+    const __m256i& av10_32_epi8,
+    const __m256i& av11_32_epi8,
+    const std::byte* QuantBDataPtr,
+    const float* scale_a0,
+    const float* scale_a1,
+    const float* scale_b,
+    __m256& acc0,
+    __m256& acc1
+)
+{
+    // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 |
+    const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(QuantBDataPtr));
+
+    // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F).
+    const __m256i low_mask = _mm256_set1_epi8(0x0F);
+    //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12);
+    // low_mask = _mm256_packus_epi16(low_mask, low_mask);
+    __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask);  // 0~31
+    // TODO: this (the second line below) is faster and does not keep low_mask in use.
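+    // (editor note, not part of the original change) Unpacking sketch: byte i of bv_packed holds
+    // element i in its low nibble and element i+32 in its high nibble, so in scalar terms
+    //     bv0[i] = packed[i] & 0x0F;           // low nibble,  values 0..15
+    //     bv1[i] = (packed[i] - bv0[i]) >> 4;  // high nibble, values 0..15
+    // The subtraction clears the low nibbles first, so the 16-bit shift below cannot drag bits
+    // from the neighbouring byte into the result and no second AND with low_mask is needed.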
+ // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); + } + } else { +#endif + //{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + // generating constant 1s is faster here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + //} + //{ + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); + const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); + //} +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); // 00110011 + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { +#endif + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4x2BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + } + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); + } + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); + } + + { + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_blklen32_r2c1blk2_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + + { + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] + ); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] + ); + } + { + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] + ); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
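+            // (editor note, not part of the original change) BlockCountK may be odd, so one
+            // 32-value block can remain after the 2-block unrolled loop. The single-block helper
+            // used below loads only 16 bytes of B with _mm_loadu_si128 and places the masked low
+            // nibbles (v0..v15) in the lower 128-bit half and the shifted high nibbles (v16..v31)
+            // in the upper half, matching the "| v0 v16 | ..." packing comment; the combined
+            // scale_a * scale_b factor is computed in scalar code per column.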
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
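+            // (editor note, not part of the original change) On the non-VNNI path,
+            // _mm256_maddubs_epi16 treats its first operand (the B nibbles) as unsigned and the
+            // second (the int8 A values) as signed, and saturates the 16-bit pair sums. With B
+            // limited to 0..15 the worst case per pair is 2 * 15 * 128 = 3840, far below 32767,
+            // so no saturation can occur here; presumably this is why the 4-bit path can use the
+            // byte multiply-add directly without widening first.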
+ if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4x2BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
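+            // (editor note, not part of the original change) QuantBZeroPointPtr only advances by
+            // one byte per PerAccuBlk2 (= 2) blocks, which suggests the B zero points are packed
+            // two 4-bit values per byte; the true/false flag passed to the accumulator functor
+            // appears to select the low or high nibble for the even or odd block of the pair.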
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h new file mode 100644 index 000000000000..174ebc580904 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -0,0 +1,541 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av11_32_epi8); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + + } else { +#endif + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } else { +#endif + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 
lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); 
+ const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; 
+ BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen64_r1c1blk1_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index b86890676070..13bd369a065b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -22,6 +22,10 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // // CompFp32 kernel implementation. @@ -150,18 +154,115 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( // CompInt8 kernel implementation. 
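// [Editor's sketch, not part of the patch] The SQ4BitGemmKernel_BlkSum_* path below works in two
// stages: the int8 x int4 block kernel accumulates scale_a * scale_b * dot(a, b) per block, and a
// follow-up float GEMM over the per-block sums (ABlockSum x QuantBBlkSum) adds the remaining
// offset correction into C (GemmFloatKernel is called with alpha = 1 and ZeroMode = false, i.e.
// accumulate). A scalar sketch of that second stage; the row-major layout of QuantBBlkSum with a
// leading dimension of CountN is an assumption, and the helper name is hypothetical:
static void AddBlkSumCorrection(const float* ABlockSum,     // CountM x BlockCountK
                                const float* QuantBBlkSum,  // BlockCountK x CountN (assumed)
                                float* C, size_t CountM, size_t CountN,
                                size_t BlockCountK, size_t ldc)
{
    for (size_t m = 0; m < CountM; ++m) {
        for (size_t n = 0; n < CountN; ++n) {
            float sum = 0.0f;
            for (size_t k = 0; k < BlockCountK; ++k) {
                sum += ABlockSum[m * BlockCountK + k] * QuantBBlkSum[k * CountN + n];
            }
            C[m * ldc + n] += sum;  // accumulate on top of the int8 kernel's output
        }
    }
}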
// +MLAS_FORCEINLINE +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_set1_epi16(1); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -185,13 +286,14 @@ MlasQ80BlkQuantRow_avx512( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; const __m512 mul = _mm512_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const size_t klen = std::min(size_t(16), step - kk); @@ -208,23 +310,46 @@ MlasQ80BlkQuantRow_avx512( // Convert int32 to int8 __m128i i0_8 = _mm512_cvtepi32_epi8(i0); _mm_storeu_si128(dst++, i0_8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i0_8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); + } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } -void MLASCALL -QuantizeARow_CompInt8_avx512( +static void +SQ4BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool ) { - MlasQ80BlkQuantRow_avx512(BlkLen, A, CountK, QuantA); + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { @@ -232,6 +357,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -239,8 +365,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h new file mode 100644 index 000000000000..7d9dc3685462 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -0,0 +1,1171 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc 
= _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +MLAS_FORCEINLINE void +accumulate_2blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float& combined_scale0, const float& combined_scale1, + const __m256i& one_16_epi16, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale_8_ps = _mm256_set_ps( + combined_scale1, combined_scale1, combined_scale0, combined_scale0, + combined_scale1, combined_scale1, combined_scale0, combined_scale0 + ); + acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + // TODO: will this (the second line below) be faster and not keep low_mask in use? 
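// [Editor's sketch, not part of the patch] The TODO above asks whether masking the shifted lane
// would beat the subtract-then-shift form used below. Both produce the same high nibbles; a
// scalar illustration over one 16-bit lane (uses <cstdint>; the helper name is hypothetical):
static inline bool HighNibbleVariantsMatch(uint16_t packed)
{
    // variant A: shift the 16-bit lane, then mask each byte (extra AND, no dependency on bv0)
    const uint16_t variantA = static_cast<uint16_t>((packed >> 4) & 0x0F0F);
    // variant B (what the code below does): clear the low nibbles per byte, then shift the lane;
    // per byte, packed - (packed & 0x0F) equals packed & 0xF0, so no second mask is needed
    const uint16_t cleared = static_cast<uint16_t>(packed & 0xF0F0);
    const uint16_t variantB = static_cast<uint16_t>(cleared >> 4);
    return variantA == variantB;  // always true
}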
+    __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4);  // 32, 33,...62, 63
+
+    int8_t zp0, zp1;
+    get_2_zps<HasZeroPoint>(QuantBZeroPointPtr, zp0, zp1);
+    bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0));
+    bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1));
+
+    //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0);
+    //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1);
+    const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)
+    );
+    const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)
+    );
+    const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
+
+    __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15);
+    const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
+    const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
+
+    __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0));
+    __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b));
+    // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0
+    __m256 scale_8_ps = _mm256_mul_ps(
+        _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+        _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+
+    acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
+
+    const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8)
+    );
+    const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8)
+    );
+    const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_);
+    const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_);
+    const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_);
+
+    __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1));
+    __m256 scale_8_ps_ = _mm256_mul_ps(
+        _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+        _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+    acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1);
+}
+
+template <bool HasZeroPoint>
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk2_avx2(
+    const __m256i& av00_32_epi8,
+    const __m256i& av01_32_epi8,
+    const __m256i& av10_32_epi8,
+    const __m256i& av11_32_epi8,
+    const std::byte* QuantBDataPtr,
+    const std::byte* QuantBZeroPointPtr,
+    const float& combined_scale00,
+    const float& combined_scale01,
+    const float& combined_scale10,
+    const float& combined_scale11,
+    __m256& acc0,
+    __m256& acc1
+)
+{
+    // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+    const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(QuantBDataPtr));
+
+    // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F).
+ // however, it is faster to generate one_16_epi16 than calling _mm256_set1_ep16(1); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + //low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this (the second line below) be faster and not keep low_mask in use? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + // generating constant 1s is fater here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + // performance gains 7% by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls. + // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1); + // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale01, + __m256& acc0) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this be faster and save a use of low_mask? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * 
QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += Q8Blk32Size * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
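// [Editor's sketch, not part of the patch] Scalar reference for what one accumulate_blklen32_*
// call contributes to a single output element for one 32-value block: the quantized dot product
// scaled by combined_scale = scale_a * scale_b. zero_point is the 4-bit weight zero point
// (typically 8 when the weights carry no explicit zero point); the helper name is hypothetical.
static float AccumulateBlk32Reference(const int8_t a[32], const uint8_t b_nibbles[32],
                                      int8_t zero_point, float scale_a, float scale_b)
{
    int32_t dot = 0;
    for (int i = 0; i < 32; ++i) {
        dot += static_cast<int32_t>(a[i]) * (static_cast<int32_t>(b_nibbles[i]) - zero_point);
    }
    return scale_a * scale_b * static_cast<float>(dot);
}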
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + // accumulate_blklen32_r2c1_avx2 + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
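// [Editor's note, illustration only] QuantBZeroPointPtr advances by one byte per pair of blocks
// because 4-bit zero points are packed two per byte. A scalar sketch of that addressing, assumed
// from the surrounding get_zp/get_2_zps usage (nibble order is an assumption; helper name is
// hypothetical):
static int8_t ZeroPointForBlock(const uint8_t* packed_zero_points, size_t block_index)
{
    const uint8_t pair = packed_zero_points[block_index / 2];
    return (block_index % 2 == 0) ? static_cast<int8_t>(pair & 0x0F)   // even block: low nibble
                                  : static_cast<int8_t>(pair >> 4);    // odd block: high nibble
}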
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + } + { + // Col1 + const float& scale_00 = 
scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); + } + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4BlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
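// [Editor's sketch, not part of the patch] The "| v0 v16 | v1 v17 | ..." comments above describe
// how one 32-value block of 4-bit weights is packed into 16 bytes: byte i carries element i in
// its low nibble and element i + 16 in its high nibble. A scalar unpack for reference (the helper
// name is hypothetical):
static void UnpackBlk32(const uint8_t packed[16], uint8_t unpacked[32])
{
    for (int i = 0; i < 16; ++i) {
        unpacked[i]      = packed[i] & 0x0F;  // low nibble  -> element i
        unpacked[i + 16] = packed[i] >> 4;    // high nibble -> element i + 16
    }
}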
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
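// [Editor's note, illustration only] After the K loop, each column's __m256 accumulator holds
// eight partial sums; the FoldAccumulators/hsum_float_8 epilogue below reduces them to one float
// per output element and then adds the optional bias. Scalar equivalent (the helper name is
// hypothetical):
static float ReduceColumnAccumulator(const float partial[8], const float* bias /* may be null */)
{
    float sum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        sum += partial[i];
    }
    return bias != nullptr ? sum + *bias : sum;
}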
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h new file mode 100644 index 000000000000..60a887345d0e --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -0,0 +1,581 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +//static MLAS_FORCEINLINE __m512i +//combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +//{ +// __m512i result = _mm512_castsi256_si512(a); +// result = _mm512_inserti64x4(result, b, 1); +// return result; +//} + +//static MLAS_FORCEINLINE void +//load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +//{ +// // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | v64 v96 | ... 
| v95 v127 | +// const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); +// const __m512i low_mask = _mm512_set1_epi8(0x0F); +// __m512i bv0_64_epi8_ = _mm512_and_si512(bv_packed, low_mask); // 0~31, 64~95 +// __m512i bv1_64_epi8_ = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 32~63, 96~127 +// +// // Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 +// __m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); +// __m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); +// __m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); +// __m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); +// +// // Compose new __m512i variables +// bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); +// bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +//} + +static MLAS_FORCEINLINE void +dot_accumulate_1blk( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i zeros = _mm512_setzero_si512(); + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +static MLAS_FORCEINLINE void +dot_accumulate_1blkvnni( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(dot0_16_epi32, bv1_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, 
av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += 
NCols4; + } // k_blks_remaining + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + 
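+                    // Two rows of A are processed per iteration: av00/av01 hold the 128 int8
+                    // values of row m for this sub-block, and av10/av11 the same K range of
+                    // row m + 1. The call below accumulates both rows against one packed 4-bit
+                    // sub-block of the current B column, updating acc0 (row m) and acc1 (row m + 1).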
accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr +=NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_blklen128_r1c1blk1_avx512( + av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h new file mode 100644 index 000000000000..bb14babd6c2b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -0,0 +1,812 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + + + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1,2~2,3~3 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 4~4,5~5,6~6,7~7 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0044115522663377 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps 
= _mm256_loadu_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = 
_mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i 
av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, + acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * 
StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, + acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2( + av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, + acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + 
accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } else { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc2[3]); + } + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t +MlasQ4Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h new file mode 100644 index 000000000000..e9df6b952bd2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -0,0 +1,852 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3}; + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i 
sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + + const 
__m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111 + const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +MLAS_FORCEINLINE void +accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = PerAccuBlk4 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], 
acc[NCols4 + 3] + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], acc[NCols4 + 3] + ); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + 
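The packed-B loads used by the accumulate helpers above (load_4blk_4b_packed_blklen32 and the single-block path in accumulate_blklen32_r1c1blk1_avx512) all assume the same 4-bit layout: byte i of a packed group carries one weight in its low nibble and the weight half a group further on in its high nibble, which is why the SIMD code masks with 0x0F, subtracts the low half, and then shifts right by 4 so the 16-bit shift cannot bleed bits across byte boundaries. A minimal scalar sketch of that unpack, assuming unsigned 4-bit storage with no zero point (the zero-point path asserts further down in this change):

#include <cstddef>
#include <cstdint>

// Scalar model of load_4blk_4b_packed_blklen32 / load_2blk_4b_packed_blklen64:
// packed[i] holds weight i in its low nibble and weight i + Count in its high
// nibble, so one Count-byte load yields 2 * Count unsigned 4-bit weights.
static void UnpackNibblesRef(const uint8_t* packed, size_t Count,
                             uint8_t* lo /* weights 0..Count-1 */,
                             uint8_t* hi /* weights Count..2*Count-1 */)
{
    for (size_t i = 0; i < Count; ++i) {
        lo[i] = packed[i] & 0x0F;          // _mm512_and_si512 with the 0x0F mask
        hi[i] = uint8_t(packed[i] >> 4);   // sub low nibble, then srli_epi16 by 4 in the SIMD code
    }
}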
QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + 
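Stripped of tiling and vectorization, every kernel in this header produces the same per-element result. A scalar reference for one entry of C, assuming one float scale per block for both operands and the 4-bit weights already unpacked to uint8; the block-sum correction applied by the separate float GEMM later in this change is intentionally not included here:

#include <cstddef>
#include <cstdint>

// Scalar reference for one output C[m][n] of the blocked int8 x int4 GEMM.
// a:       int8 activations for row m, BlockCountK blocks of BlkLen values
// b:       unpacked uint8 weights (0..15) for column n, same blocking
// scale_a, scale_b: one float scale per block
static float GemmElementRef(const int8_t* a, const float* scale_a,
                            const uint8_t* b, const float* scale_b,
                            size_t BlockCountK, size_t BlkLen, float bias)
{
    float sum = bias;
    for (size_t k = 0; k < BlockCountK; ++k) {
        int32_t dot = 0;
        for (size_t i = 0; i < BlkLen; ++i) {
            dot += int32_t(b[k * BlkLen + i]) * int32_t(a[k * BlkLen + i]);  // unsigned B times signed A
        }
        sum += float(dot) * scale_a[k] * scale_b[k];  // per-block dequantization
    }
    return sum;
}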
} + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + 
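The VNNI variants above (accumulate_1blk_dot_avx512vnni and the *_avx512vnni block accumulators) rely on _mm512_dpbusd_epi32 / _mm256_dpbusd_epi32, where each 32-bit lane adds the dot product of four unsigned bytes from the first operand with four signed bytes from the second. A scalar model of a single lane, ignoring any wraparound on accumulator overflow:

#include <cstdint>

// One 32-bit lane of _mm512_dpbusd_epi32(acc, b_u8, a_s8): four u8 x s8
// products summed and added to the running accumulator.
static int32_t DpbusdLaneRef(int32_t acc, const uint8_t b[4], const int8_t a[4])
{
    for (int i = 0; i < 4; ++i) {
        acc += int32_t(b[i]) * int32_t(a[i]);
    }
    return acc;
}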
+ while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); + } + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen32 * 
PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h new file mode 100644 index 000000000000..2a65ac4af0c1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -0,0 +1,840 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +static MLAS_FORCEINLINE __m256 +h_add_512(__m512 a) +{ + return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); +} + +static MLAS_FORCEINLINE float +hsum_float_16(const __m512 x) +{ + __m256 hi = h_add_512(x); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} + +static MLAS_FORCEINLINE __m512i +combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +{ + __m512i result = _mm512_castsi256_si512(a); + result = _mm512_inserti64x4(result, b, 1); + return result; +} + +static MLAS_FORCEINLINE void +load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 + + //// Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 + //__m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); + //__m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); + //__m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); + //__m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); + + //// Compose new __m512i variables + //bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); + //bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +} + +static MLAS_FORCEINLINE __m512i +load_1blk_4b_packed_blklen64(const std::byte* QuantBDataPtr) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16( + _mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + __m512i bv_64_epi8 = combine_two_m256i_to_m512i(bv0_32_epi8, bv1_32_epi8); + return bv_64_epi8; +} + +static MLAS_FORCEINLINE __m512i +horizontal_add_epi32(__m512i a, __m512i b) +{ + __m512i t1 = _mm512_unpacklo_epi32(a, b); + __m512i t2 = _mm512_unpackhi_epi32(a, b); + __m512i sum = _mm512_add_epi32(t1, t2); + return sum; +} + +static MLAS_FORCEINLINE __m512i +generate_ones_32_epi16() +{ + const __m512i zeros = _mm512_setzero_si512(); + return _mm512_srli_epi16(_mm512_ternarylogic_epi64(zeros, zeros, zeros, 1), 15); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blk( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + //const __m512i& one_32_epi16, + __m512& acc) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // sum for blk: 0 0 1 1 0 0 1 1... + __m512i one_32_epi16 = generate_ones_32_epi16(); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // sum for blk: 0 1 0 1... + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blkvnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + // const __m512i& one_32_epi16, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); + + __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... 
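On targets without VNNI, dot_accumulate_2blk reaches the same 32-bit sums in two steps: _mm512_maddubs_epi16 forms pairwise u8*s8 sums in 16-bit lanes, and _mm512_madd_epi16 against the all-ones constant from generate_ones_32_epi16 (built without a memory load: NOT of zero, shifted right by 15, leaves 1 in every 16-bit lane) adds adjacent 16-bit lanes into 32-bit lanes. maddubs saturates its 16-bit results, but with 4-bit weights (0..15) a two-product sum is at most 2 * 15 * 128 = 3840, well inside int16 range, so no saturation occurs. A scalar sketch of one resulting 32-bit lane:

#include <cstdint>

// maddubs(b_u8, a_s8): two u8 x s8 products summed into an int16 lane.
// madd(ones, x_s16):   two adjacent int16 lanes summed into an int32 lane.
// Chained, they give a 4-element u8 x s8 dot product per 32-bit lane.
static int32_t MaddubsMaddLaneRef(const uint8_t b[4], const int8_t a[4])
{
    int16_t lo = int16_t(int32_t(b[0]) * a[0] + int32_t(b[1]) * a[1]);  // maddubs, 16-bit lane 2k
    int16_t hi = int16_t(int32_t(b[2]) * a[2] + int32_t(b[3]) * a[3]);  // maddubs, 16-bit lane 2k+1
    return int32_t(lo) + int32_t(hi);                                   // madd with the 1,1 pair
}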
+ __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blkvnni( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } else { + dot_accumulate_2blk( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blk( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk2_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } else { + dot_accumulate_2blk( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } else { + const __m512i zeros = _mm512_setzero_si512(); + // const __m512i one_32_epi16_ = 
_mm512_andnot_epi32(zeros, zeros); + // const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); + + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx512( + const __m512i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av_32_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } else { + const __m512i one_32_epi16 = _mm512_set1_epi16(1); + + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr 
= QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + 
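The pointer arithmetic in the K loop above (per-column offsets of StrideQuantBData, advancing by StrideQuantBData * NCols4 per group, and QuantBDataColPtr stepping by NCols4 * BlockCountK * BlkDataSizeInBytes per column panel) implies the packed-B layout that the packing code further down produces: columns grouped in fours, with the PerAccuBlk-block chunks of the four columns interleaved inside each group. A hypothetical offset helper mirroring that addressing, ignoring the tail-panel cases the real GetContinueLayoutOffset* helpers handle:

#include <cstddef>

// Hypothetical byte offset of block k_blk of column n inside packed B,
// assuming full 4-column panels and full PerAccuBlk groups.
static size_t PackedBOffsetSketch(size_t n, size_t k_blk, size_t BlockCountK,
                                  size_t PerAccuBlk, size_t BlkDataSizeInBytes)
{
    const size_t col_group    = n / 4;              // which 4-column panel
    const size_t col_in_group = n % 4;
    const size_t k_group      = k_blk / PerAccuBlk;
    const size_t k_in_group   = k_blk % PerAccuBlk;
    return (col_group * 4 * BlockCountK      // whole panels before this one
            + k_group * 4 * PerAccuBlk       // completed k groups, 4 columns each
            + col_in_group * PerAccuBlk      // this column inside the current k group
            + k_in_group) * BlkDataSizeInBytes;
}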
*(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM % NRows2 == 0); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + 
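h_add_512 and hsum_float_16, defined near the top of this header, collapse a 16-float accumulator into the single value stored into C here; the R2xC4 kernel above uses _mm512_reduce_add_ps for the same reduction. The shuffle sequence is just a full horizontal sum:

// Scalar equivalent of hsum_float_16 (fold 512 -> 256 -> 128 -> scalar):
// the sum of all 16 lanes.
static float HsumFloat16Ref(const float v[16])
{
    float sum = 0.0f;
    for (int i = 0; i < 16; ++i) {
        sum += v[i];
    }
    return sum;
}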
*(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + //const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + QuantAPtr += BlkLen64; + 
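dot_accumulate_2blk and the accumulate_blklen64_r*c1blk2_avx512 helpers above broadcast a pair of adjacent block scales by reinterpreting the two floats as one double (_mm256_broadcast_sd, then _mm512_broadcast_f32x8), so the expanded vector alternates scale[0], scale[1], scale[0], scale[1], ... and lines up with the interleaved per-block sums ("0 1 0 1 ..."). A scalar sketch of that expansion, assuming scale points at two valid floats:

// Scalar model of _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale))
// followed by _mm512_broadcast_f32x8: 16 lanes alternating the two scales.
static void BroadcastScalePairRef(const float scale[2], float out[16])
{
    for (int i = 0; i < 16; ++i) {
        out[i] = scale[i % 2];  // even lanes get block 0's scale, odd lanes block 1's
    }
}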
QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + else + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + else + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6477a2019b21..6a5c01162c51 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -23,6 +23,10 @@ Module Name: #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32( @@ -146,6 +150,7 @@ void SQ4BitGemmM1Kernel_CompInt8_avx512vnni( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -157,44 +162,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( ) { if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } + assert(false); } else { constexpr bool HasZeroPoint = false; if (BlkLen == 16) { @@ -212,6 +180,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } else if (BlkLen == 32) { SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -237,52 +206,134 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +MLAS_FORCEINLINE size_t -SQ4BitGemmKernel_CompInt8_avx512vnni( - size_t BlkLen, +SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, + const std::byte* /*QuantBZeroPoint*/, float* C, size_t CountM, size_t CountN, - size_t CountK, + size_t /*CountK*/, size_t BlockCountK, + const float* Bias, size_t ldc, - const float* Bias + const float* ABlockSum, + const float* QuantBBlkSum ) { - MLAS_UNREFERENCED_PARAMETER(ldc); - - if (CountM == 0) { - return 0; + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); } - SQ4BitGemmM1Kernel_CompInt8_avx512vnni( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockCountK, - Bias - ); + 
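Each MlasQ4Int8GemmKernelBlkLen{16,32,64}Avx512 driver above partitions the output the same way: rows in pairs and columns in fours go to the R2xC4 kernel, and the leftover row and/or columns are handled by the R2xC1 / R1xC4 / R1xC1 variants. A sketch of that partitioning, with DoTile as a stand-in for the four kernels (not a real MLAS entry point):

#include <cstddef>

// M/N partitioning used by the MlasQ4Int8GemmKernelBlkLen*Avx512 drivers:
// a main multiple-of-2 x multiple-of-4 region plus up to three edge regions.
template <typename TileFn>
static void TileDispatchSketch(size_t CountM, size_t CountN, TileFn DoTile)
{
    const size_t remainingRows = CountM % 2, multipleRows = CountM - remainingRows;
    const size_t remainingCols = CountN % 4, multipleCols = CountN - remainingCols;

    if (multipleRows && multipleCols)  DoTile(0, 0, multipleRows, multipleCols);                            // R2xC4
    if (multipleRows && remainingCols) DoTile(0, multipleCols, multipleRows, remainingCols);                // R2xC1
    if (remainingRows && multipleCols) DoTile(multipleRows, 0, remainingRows, multipleCols);                // R1xC4
    if (remainingRows && remainingCols) DoTile(multipleRows, multipleCols, remainingRows, remainingCols);   // R1xC1
}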
float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; - return 1; + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; } void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ); +static void +SQ4BitGemmPackQuantBDataAndBlkSum512vnni( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); +} + // // Kernel dispatch structure definition. // @@ -291,6 +342,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -298,8 +350,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni; - d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 706e08fc467b..177f5518bb89 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -14,13 +14,24 @@ SQ4BitGemmPackQuantBDataSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - constexpr size_t BlkBitWidth = 4; - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (ComputeType == CompInt8) { + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + 
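+        // Illustrative sizing (example numbers, not from the patch): N = 4096, K = 4096,
+        // BlkLen = 32 gives BlockCountK = 128, so
+        //   PackedQuantBDataSize = 4096 * 128 * 16 B = 8 MiB   (packed 4-bit data)
+        //   ScaleSize            = 4096 * 128 * 4 B  = 2 MiB   (float scales)
+        //   BlkSumSize           = (4096 / 16) * 128 * 16 * 4 B = 2 MiB   (block sums, N padded to a multiple of 16)
+        // The alignment slack added below lets the caller place each sub-buffer on its
+        // required boundary inside the single allocation.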
PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } else { + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } } static void @@ -100,6 +111,216 @@ SQ4BitGemmPackQuantBData( ); } +static size_t +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + size_t scale_dst_offset = T * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += t * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += k_sub_or_blk * 4 + t; + } + return scale_dst_offset; +} + +static size_t +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; + size_t scale_dst_offset = T * 4 * BlockCountK; + if (te) { + scale_dst_offset += t * BlockCountK + k_blk; + } else { + scale_dst_offset += k_subblk * blks_per_sub * 4; + if (be) { + scale_dst_offset += b * 4 + t; + } else { + scale_dst_offset += t * blks_per_sub + b; + } + } + return scale_dst_offset; +} + +static void +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) +{ + constexpr size_t BlkBitWidth = 4; + const size_t BlkBytePairCount = BlkLen / 4; + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // for avx2 + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // for the remaining blk, it shall be: + // dst blklen32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | + + // for avx512 + // dst: | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + // for the remaining blk, it shall be: + // dst blklen64: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // dst blklen32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; + + const size_t src_data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset; + + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. + PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } else { + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t k_blk = k_subblk * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } + ); +} + +//#include + +static void +ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const 
size_t BlockCountK) +{ + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 8; + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + size_t src_zp_offset = ZPCountK * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; + const std::byte low_mask{0X0F}; + zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlkLen == 16) { // TODO + + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } + } + ); +} + +static void +PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + // // Workspace size calculation function implementation. 
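+// Note on the CompInt8 per-GEMM workspace in the hunk below: each row of A needs, per
+// K block, the block-quantized data plus one extra float for the block sum. Assuming
+// Q8BlkSize(BlkLen) is sizeof(float) + BlkLen (scale followed by BlkLen int8 values),
+// the size becomes M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); e.g. M = 1,
+// K = 4096, BlkLen = 32 gives 1 * 128 * (36 + 4) = 5120 bytes.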
// @@ -119,7 +340,8 @@ SQ4BitGemmPerGemmWorkspaceSize( case CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } default: { @@ -288,6 +510,20 @@ load_and_mul_sum_s8_quads_with_zp_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); } +template +void MLAS_FORCEINLINE +get_2_zps(const std::byte* QuantBZeroPointPtr, int8_t& zp0, int8_t& zp1) +{ + if constexpr (HasZeroPoint) { + zp0 = std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}); + zp1 = std::to_integer((*QuantBZeroPointPtr) >> 4); + } else { + zp0 = 8; + zp1 = 8; + (void)QuantBZeroPointPtr; + } +} + template int8_t MLAS_FORCEINLINE get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) @@ -375,7 +611,7 @@ FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, con return acc_y; } -static inline float +static MLAS_FORCEINLINE float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); @@ -417,4 +653,27 @@ FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, con _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); } + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8(__m256 v0, __m256 v1) +{ + __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); + __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); + + __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); + __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); + + return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); +} + +// horizontally add 8 int32_t +static MLAS_FORCEINLINE int +hsum_8_epi32(const __m256i a_8_epi32) +{ + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} } // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 250ffeacd7c2..895ce6cd091c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -7,20 +7,6 @@ #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" -void -SQ4BitGemmM1Kernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -); - template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( @@ -240,6 +226,7 @@ template accumulator> void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -273,6 +260,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( int64_t nblk = 
(int64_t)(CountN)-4; while (nblk >= 0) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -286,14 +274,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -320,7 +308,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -331,9 +320,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -374,6 +363,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( nblk += NCols; for (int64_t n = 0; n < nblk; n++) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -383,14 +373,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -399,7 +389,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); // increment block pointers 
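+            // Quantized-A layout change in this kernel (a sketch; per K block):
+            //   old: interleaved Q8 blocks [scale (float)][BlkLen int8], stride Q8BlkSize(BlkLen)
+            //   new: QuantA      -> contiguous int8 data, stride BlkLen
+            //        QuantAScale -> contiguous float scales, one per block
+            // so the increments below advance QuantAPtr by BlkLen * 2 and QuantAScalePtr by 2
+            // (two blocks are consumed per loop iteration).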
- QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -410,9 +401,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h new file mode 100644 index 000000000000..45c3963365e6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -0,0 +1,759 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_zp_avx2( + const __m256i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + const std::byte* QuantBZeroPointPtr, + __m256& acc, + const __m256i& low_mask +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(low_mask, bv_32_epi8); + + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } else { +#endif + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + __m256& acc0, + const __m256i& low_mask +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_set1_ps(*(scale_a) 
* *(scale_b)); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { +#endif + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // accumulate_blklen32_r1c1blk2_zp_is_8_avx2 is much faster than + // accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2: + // BlkBitWidth:4/BlkLen:32/M:1/N:2560/K:2560/Threads:8/Symmetric:1/HasBias:0/ComputeType:4 + // 36591 vs 40270 ns (the main is 51836 ns). both are not as good as main with genai. + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { +#endif + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps( + _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) + ); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const __m256& scale_a0_8_ps, + const __m256& scale_a1_8_ps, + const std::byte* QuantBDataPtr, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) { + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { +#endif + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +#if !defined(__GNUC__) || (__GNUC__ > 10) + } +#endif +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + const 
size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc[0], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale, acc[1], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + if constexpr 
(HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountN < NCols4); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const 
__m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); + } else { + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE +void +MlasQ4Int8GemmM1KernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr); + } +} + +//#define SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout 1 +void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + // port from neon implementation + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout +#else + constexpr bool HasZeroPoint = false; +#endif + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + //const size_t StrideQuantBScale = BlockCountK; + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bzp8 = _mm256_set1_epi8(8); + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + (void)StrideQuantBZeroPoint; +#else + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); +#endif + const size_t NCols = 4; + constexpr size_t StrideQuantBScale2 = 2; + constexpr size_t StrideQuantBScale1 = 1; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 + acc0 = _mm256_setzero_ps(), + acc1 = _mm256_setzero_ps(), + acc2 = _mm256_setzero_ps(), + acc3 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen))); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + + // Col1 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale2, acc1, low_mask, bzp8); +#else + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale2)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); +#endif + + // Col2 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale2, acc2, low_mask, bzp8); +#else + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale2)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); +#endif + // Col3 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale2, acc3, low_mask, bzp8); +#else + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale2)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2 * NCols; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_0 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_0, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_0, acc0); +#endif + + // Col1 + const float& scale_1 = scale_a0 * (QuantBScalePtr + StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + StrideQuantBData, scale_1, acc1, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_1, acc1); +#endif + + // Col2 + const 
float& scale_2 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_2, acc2, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_2, acc2); +#endif + + // Col3 + const float& scale_3 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_3, acc3, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_3, acc3); +#endif + } + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + + // move to next NCols columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + nblk -= NCols; + } + + nblk += NCols; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk0)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk1)); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_00, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); +#endif + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h new file mode 100644 index 000000000000..e9c3812bde89 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -0,0 +1,312 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + const bool is_lower_half_byte_zp, + __m256& acc0, + const __m256i& low_mask +) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + const __m256i bzp8 = _mm256_set1_epi8(get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr)); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_is_8_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t SubblkLen64 = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen64; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const size_t StrideQuantBData1 = 1 * SubblkDataSizeInBytes; + const size_t StrideQuantBScale1 = 1; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc[0], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 
StrideQuantBScale1, QuantBZeroPointPtr + StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[1], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[2], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[3], low_mask); + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale1, acc[1], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, acc[2], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, acc[3], low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen64; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + assert(CountN < NCols4); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc0, low_mask); + } else { + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE void +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr); + } +} diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 08066f030a38..ff8943de7967 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -50,7 +50,8 @@ void MergeWeights(const T* q, const T* k, const T* v, std::vector& result, in // Merge 2-D weights (q, k and v) by concatenating them row by row. template -void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, std::vector& result, int64_t hidden_size) { +void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, + std::vector& result, int64_t hidden_size) { const T* q = q_weight; const T* k = k_weight; const T* v = v_weight; @@ -144,7 +145,8 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, return graph_utils::AddInitializer(graph, initializer); } -static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, const logging::Logger& logger) { +static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, + const logging::Logger& logger) { // Validate mask input shape (batch_size, sequence_length) and data type. // Note that batch_size and sequence_length could be symbolic. const TensorShapeProto* mask_shape = mask_input->Shape(); @@ -208,9 +210,11 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && // Add node.GetOutputEdgesCount() == 5/6 for distilbert + // Add node.GetOutputEdgesCount() == 5/6 for distilbert + if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && - graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && + node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { // Bias is an optional input for LayerNorm // Get hidden size from layer norm bias tensor shape. 
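Note on the new sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h above: per 64-element block, the accumulate_blklen64_* helpers form an integer dot product between the int8 activations and the nibble-packed 4-bit weights (low nibbles hold elements 0..31, high nibbles hold elements 32..63, per the "| v0 v32 | v1 v33 | ... |" comment), subtract the block zero point (8 in the symmetric case), and scale by the product of the two block scales. A minimal scalar sketch of the same arithmetic, with illustrative names only and not part of the patch:

#include <cstddef>
#include <cstdint>

// Scalar reference for one 64-element block. Byte i of the packed weights holds
// weight[i] in its low nibble and weight[i + 32] in its high nibble. The result
// is sum(a[i] * (w[i] - zero_point)) * scale_a * scale_b, which is what the
// vectorized helpers accumulate into acc0.
float AccumulateBlk64Reference(const int8_t a[64],          // quantized activations
                               const uint8_t packed_w[32],  // packed 4-bit weights
                               float scale_a, float scale_b,
                               uint8_t zero_point) {         // 8 when B has no zero-point input
    int32_t acc = 0;
    for (size_t i = 0; i < 32; ++i) {
        const int32_t w_lo = static_cast<int32_t>(packed_w[i] & 0x0F) - zero_point;  // elements 0..31
        const int32_t w_hi = static_cast<int32_t>(packed_w[i] >> 4) - zero_point;    // elements 32..63
        acc += static_cast<int32_t>(a[i]) * w_lo;
        acc += static_cast<int32_t>(a[i + 32]) * w_hi;
    }
    return static_cast<float>(acc) * scale_a * scale_b;
}

The _mm256_maddubs_epi16 / _mm256_sign_epi8 pairing in the vector code computes the same products: taking the absolute value of the weight lane and transferring its sign onto the activation lane satisfies maddubs's unsigned-by-signed operand requirement while keeping each product equal to a[i] * w[i].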
const NodeArg& layer_norm_bias = *(node.InputDefs()[2]); if (!optimizer_utils::IsShapeKnownOnAllDims(layer_norm_bias, 1)) { @@ -242,8 +246,10 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, fused_count++; modified = true; } - } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (static_cast(reshape_count) + shape_count) == node.GetOutputEdgesCount()) { // GPT - if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) { + } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && + (static_cast(reshape_count) + shape_count) == node.GetOutputEdgesCount()) { // GPT + if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, + logger)) { fused_count++; modified = true; } @@ -301,7 +307,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, return false; } - if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(), q_reshape, q_transpose, num_heads, head_size, logger)) { + if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(), + q_reshape, q_transpose, num_heads, head_size, logger)) { DEBUG_LOG("CheckNodesInPathQ returns false"); return false; } @@ -365,7 +372,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, } // Now everything is ready, we will start fusing subgraph. - NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(), logger); + NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(), + logger); if (nullptr == mask_int32) { DEBUG_LOG("Failed to convert mask to int32"); return false; @@ -438,7 +446,8 @@ static bool FuseSubGraphQK(Node& layer_norm, } std::vector nodes_to_remove; - if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, + if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, + mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, num_heads, head_size, mask_nodes.mask_filter_value, logger)) { return false; } @@ -529,7 +538,8 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm, } std::vector nodes_to_remove; - if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, + if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, + mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size, num_heads, head_size, mask_nodes.mask_filter_value, logger)) { return false; } @@ -615,7 +625,12 @@ After Fusion: | | Add */ -bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_int32_map, const logging::Logger& logger) { +bool AttentionFusion::FuseSubGraph(Node& layer_norm, + const Node& add_after_layer_norm, + Graph& graph, + int64_t hidden_size, + std::map& mask_int32_map, + const logging::Logger& logger) { std::vector parent_path{ {0, 0, "Add", {7, 13}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, @@ -657,7 +672,9 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer int64_t num_heads = 0; // will be updated in CheckNodesInPathV int64_t head_size = 0; // will be updated in CheckNodesInPathV NodeIndex record_node_idx = 0; // will be updated in CheckNodesInPathV if it's distilbert model - if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, 
qkv_matmul, v_transpose, v_reshape, num_heads, head_size, hidden_size, record_node_idx, logger)) { + if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, + qkv_matmul, v_transpose, v_reshape, num_heads, + head_size, hidden_size, record_node_idx, logger)) { DEBUG_LOG("CheckNodesInPathV return false"); return false; } @@ -672,7 +689,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer } // store parent path - std::vector> parent_path_nodes{reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul}; + std::vector> parent_path_nodes{ + reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul}; // Find mask nodes: Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax --> [MatMul] // The "Cast" node in parentheses is optional. @@ -681,10 +699,13 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, qkv_matmul, mask_nodes, logger, false)) { NodeArg* mask_input = graph.GetNode(mask_nodes.unsqueeze_1->Index())->MutableInputDefs()[0]; - return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); - } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul, mask_nodes_distilbert, record_node_idx, logger)) { + return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input, + parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); + } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul, + mask_nodes_distilbert, record_node_idx, logger)) { NodeArg* mask_input = graph.GetNode(mask_nodes_distilbert.equal->Index())->MutableInputDefs()[0]; - return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); + return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input, + parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger); } else { DEBUG_LOG("Failed in match input mask subgraph"); return false; diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 12746ad53123..c7d2d95e6121 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -58,16 +58,11 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { } bool ConvFusionDataTypeCheck(const Node& conv_node) { - // TODO(hasesh): The CPU and CUDA EP only support float type for the Conv+Activation + // TODO(hasesh): The CPU EP only supports float type for the Conv+Activation // and the Conv+Add+Relu fusions. // Assess the support level for the other compatible EPs and if they also // only support float, remove the EP check altogether. 
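Note on conv_activation_fusion.cc: the remaining selector only fuses when the Conv output feeds exactly one downstream node on the same execution provider and is not also a graph output. A toy sketch of that "lone consumer" test, using simplified structures rather than ORT's Graph API:

#include <cstddef>
#include <vector>

// Toy node used only for this sketch.
struct ToyNode {
    std::vector<const ToyNode*> consumers;  // nodes reading this node's first output
    bool produces_graph_output = false;     // first output is also a graph output
};

// Returns the single downstream node, or nullptr if the output fans out, is
// unused, or doubles as a graph output (fusing would then swallow a value the
// caller of the graph still needs).
const ToyNode* GetLoneConsumer(const ToyNode& node) {
    if (node.produces_graph_output || node.consumers.size() != 1) {
        return nullptr;
    }
    return node.consumers.front();
}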
const std::string_view node_ep = conv_node.GetExecutionProviderType(); - if (node_ep == kCudaExecutionProvider) { - if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { - return false; - } - } if (node_ep == kCpuExecutionProvider) { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) && @@ -120,7 +115,9 @@ class ConvActivationSelector : public NodeSelector { } // check EP type and activation - if (node_ep == kCudaExecutionProvider || node_ep == kRocmExecutionProvider) { + if (node_ep == kCudaExecutionProvider) { + return std::nullopt; + } else if (node_ep == kRocmExecutionProvider) { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } @@ -142,43 +139,6 @@ class ConvActivationSelector : public NodeSelector { } }; -class ConvAddRelu : public NodeSelector { - public: - ConvAddRelu() = default; - - std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override { - const std::string_view node_ep = node.GetExecutionProviderType(); - // only for CUDA EP - if (node_ep != kCudaExecutionProvider) { - return std::nullopt; - } - - if (!ConvFusionDataTypeCheck(node)) { - return std::nullopt; - } - - const auto* add_node = GetLoneConsumerNode(graph_viewer, node); - if (!add_node || - !graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) || - add_node->GetExecutionProviderType() != node_ep) { - return std::nullopt; - } - - const auto* relu_node = GetLoneConsumerNode(graph_viewer, *add_node); - if (!relu_node || - !graph_utils::IsSupportedOptypeVersionAndDomain(*relu_node, "Relu", {6, 13, 14}) || - relu_node->GetExecutionProviderType() != node_ep) { - return std::nullopt; - } - - NodesToOptimizeIndicesBuilder builder{}; - builder.target_node = node.Index(); - builder.output_nodes = {add_node->Index(), - relu_node->Index()}; - return builder.Build(); - } -}; - } // namespace selectors #endif // !defined(ORT_MINIMAL_BUILD) @@ -304,22 +264,9 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { #endif } -void RegisterConvAddReluFusionRules(SelectorActionRegistry& registry) { - const auto name = "ConvAddRelu"; - auto action = std::make_unique(); -#if !defined(ORT_MINIMAL_BUILD) - auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, - std::move(selector), std::move(action)); -#else - registry.RegisterAction(name, std::move(action)); -#endif -} - SelectorActionRegistry CreateSelectorActionRegistry() { SelectorActionRegistry registry{}; RegisterConvActivationFusionRules(registry); - RegisterConvAddReluFusionRules(registry); return registry; } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ab1dbaea7b7f..08284e67277e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -189,7 +189,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool disable_quant_qdq = 
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -281,6 +282,11 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, @@ -309,14 +315,15 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_dml_eps)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); + transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); @@ -419,7 +426,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -444,7 +452,8 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index a1c7f8de9e6f..e266946b0d9e 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -12,7 +12,7 @@ namespace onnxruntime { * It matches following pattern: * Pad * | - * Conv/MaxPool + * Conv/MaxPool/AveragePool */ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { // if Pad has input axis, don't fuse it. 
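Note on pad_fusion.cc: the following hunk extends the fusable child operators to AveragePool. The pads merge itself is outside these hunks; conceptually it adds the Pad's spatial begin/end amounts into the child operator's pads attribute. An illustrative sketch under that assumption (zero-valued constant Pad restricted to spatial dims, names are mine):

#include <cstddef>
#include <cstdint>
#include <vector>

// pad_pads uses the ONNX Pad layout over all input dims:
//   [x1_begin, ..., xn_begin, x1_end, ..., xn_end]
// child_pads uses the spatial-only Conv/Pool layout:
//   [d1_begin, ..., dk_begin, d1_end, ..., dk_end]
// For an N x C x spatial input, the first two dims of the Pad must be zero and
// the spatial amounts are simply added into the child op's pads.
std::vector<int64_t> FoldPadIntoChildPads(const std::vector<int64_t>& pad_pads,
                                          const std::vector<int64_t>& child_pads,
                                          size_t num_spatial_dims) {
    const size_t rank = pad_pads.size() / 2;                 // e.g. 4 for NCHW
    const size_t spatial_offset = rank - num_spatial_dims;   // skip N and C
    std::vector<int64_t> merged = child_pads;                // size 2 * num_spatial_dims
    for (size_t d = 0; d < num_spatial_dims; ++d) {
        merged[d] += pad_pads[spatial_offset + d];                            // begin
        merged[num_spatial_dims + d] += pad_pads[rank + spatial_offset + d];  // end
    }
    return merged;
}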
@@ -28,6 +28,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log const Node& child_node = *node.OutputNodesBegin(); if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) && !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) { return false; } diff --git a/onnxruntime/core/optimizer/pad_fusion.h b/onnxruntime/core/optimizer/pad_fusion.h index a1b6978a83d1..ca05d219b7e2 100644 --- a/onnxruntime/core/optimizer/pad_fusion.h +++ b/onnxruntime/core/optimizer/pad_fusion.h @@ -8,7 +8,7 @@ namespace onnxruntime { /* * This fusion submerges a Pad operator to it's child - * Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition() + * Conv or MaxPool or AveragePool operator, if and only if PadFusion::SatisfyCondition() * is true. */ class PadFusion : public RewriteRule { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index f0e76312d6e0..7b518947138a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -3,8 +3,13 @@ #include "core/optimizer/qdq_transformer/qdq_propagation.h" +#include #include +#include +#include +#include +#include "core/common/inlined_containers_fwd.h" #include "core/graph/extended_graph_edge.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" @@ -17,39 +22,147 @@ namespace onnxruntime { namespace { bool CanNodePropagate(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13}); + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13, 21}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}); } -// convert this: src_node -> dst_node -// to this: src_node -> Q -> DQ -> dst_node -// assumptions: -// 1. insertion_edge is valid - node indexes refer to valid nodes, arg name refers to a valid NodeArg, and it -// corresponds to an actual graph relationship -// 2. 
scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers -Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, - NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, - const std::string& qdq_domain, const logging::Logger& logger) { - auto* src_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source); - auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - - ORT_ENFORCE(src_node || dst_node, "At least one graph node must be specified in the propagation edge."); - - const auto& base_name = insertion_edge.arg_name; +// Makes matching attributes for new QuantizeLinear nodes from an existing DequantizeLinear node. +NodeAttributes MakeQAttrsFromDQ(const Node& dq_node) { + assert(dq_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchDQNode(). + // In opset <= 21, all DQ attributes (i.e., axis and block_size) are also Q attributes. + // So, set a copy of the DQ attributes. + return dq_node.GetAttributes(); +} + +// Makes matching attributes for new DequantizeLinear nodes from an existing QuantizeLinear node. +NodeAttributes MakeDQAttrsFromQ(const Node& q_node) { + assert(q_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchQNode(). + const NodeAttributes& q_attrs = q_node.GetAttributes(); + if (q_attrs.empty()) { + return {}; + } + + // In opset <= 21, only the "axis" and "block_size" attributes for Q are also DQ attributes. + NodeAttributes dq_attrs; + + auto axis_attr_it = q_attrs.find("axis"); + if (axis_attr_it != q_attrs.end()) { + dq_attrs.insert({axis_attr_it->first, axis_attr_it->second}); + } + + auto block_size_attr_it = q_attrs.find("block_size"); + if (block_size_attr_it != q_attrs.end()) { + dq_attrs.insert({block_size_attr_it->first, block_size_attr_it->second}); + } + + return dq_attrs; +} + +// Validates edges into which to insert Q -> DQ ops. +// - Must have at least one edge. +// - All edges must correspond to the same graph NodeArg (i.e., same source but potentially different destination). +// - All edges must be attached to either a source node or a destination node. 
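Note: a Q -> DQ pair inserted by this transformer reproduces the original tensor only up to quantization error. QuantizeLinear maps x to saturate(round(x / scale) + zero_point) and the matching DequantizeLinear maps back to (q - zero_point) * scale. A minimal scalar sketch (uint8, round-half-to-even as in ONNX; illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdint>

// QuantizeLinear for a single uint8 value, with saturation to [0, 255].
uint8_t QuantizeU8(float x, float scale, uint8_t zero_point) {
    const float q = std::nearbyint(x / scale) + static_cast<float>(zero_point);
    return static_cast<uint8_t>(std::clamp(q, 0.0f, 255.0f));
}

// DequantizeLinear for a single uint8 value.
float DequantizeU8(uint8_t q, float scale, uint8_t zero_point) {
    return (static_cast<int32_t>(q) - static_cast<int32_t>(zero_point)) * scale;
}

// Example: with scale = 0.1f and zero_point = 128,
//   QuantizeU8(1.23f, 0.1f, 128) == 140 and DequantizeU8(140, 0.1f, 128) is about 1.2f,
// so an inserted Q -> DQ pair reproduces the value to within roughly scale / 2.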
+Status ValidateQDQInsertionEdges(Graph& graph, gsl::span insertion_edges) { + const size_t num_edges = insertion_edges.size(); + ORT_RETURN_IF(num_edges == 0, "Expected at least one edge into which to insert QDQ pair."); + + const ExtendedGraphEdge& first_edge = insertion_edges[0]; + const Node* src_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source); + const Node* first_dst_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + const std::string& node_arg_name = first_edge.arg_name; + ORT_RETURN_IF_NOT(graph.GetNodeArg(node_arg_name) != nullptr, + "QDQ insertion edge does not have a valid graph NodeArg for ", node_arg_name); + ORT_RETURN_IF_NOT(src_node != nullptr || first_dst_node != nullptr, + "QDQ insertion edge [0] for NodeArg ", node_arg_name, + " must have a source or a destination node"); + + for (size_t i = 1; i < num_edges; i++) { + const ExtendedGraphEdge& insertion_edge = insertion_edges[i]; + ORT_RETURN_IF_NOT(insertion_edge.arg_name == node_arg_name, + "QDQ insertion edge [", i, "] has NodeArg ", insertion_edge.arg_name, + " but expected NodeArg ", node_arg_name); + + const Node* edge_dst_node = insertion_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + ORT_RETURN_IF_NOT(src_node != nullptr || edge_dst_node != nullptr, + "QDQ insertion edge [", i, "] for NodeArg ", node_arg_name, + " must have a source or a destination node"); + } + + return Status::OK(); +} + +// Logs information about the edges into which Q/DQ nodes will be inserted in InsertQDQPairs(). +// Assumes the edges have already been validated. +void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, const CodeLocation& code_location, + const Graph& graph, gsl::span edges) { + auto logging_data_type = logging::DataType::SYSTEM; + if (!logger.OutputIsEnabled(severity, logging_data_type)) { + return; + } + + const Node* src_node = edges[0].GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source); + const auto& node_arg_name = edges[0].arg_name; + std::string src_label = src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") + : "input"; + std::ostringstream dst_labels; + const size_t num_edges = edges.size(); + + for (size_t i = 0; i < num_edges; ++i) { + const ExtendedGraphEdge& edge = edges[i]; + const Node* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + dst_labels << (dst_node ? MakeString("dst node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")") + : "output") + << (i == num_edges - 1 ? "" : ","); + } + + logging::Capture(logger, severity, logging::Category::onnxruntime, logging_data_type, code_location).Stream() + << "Inserted Q/DQ pair between " + << (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") + : "input") + << " and " << dst_labels.str() + << " at NodeArg \"" << node_arg_name << "\"."; +} + +// convert this: src_node (or graph input) --+--> dst_node_0 (or graph output) +// | +// +--> dst_node_1 +// | ... +// +--> dst_node_n +// +// to this: src_node (or graph input) -> Q --+--> DQ -> dst_node_0 (or graph output) +// | +// +--> DQ -> dst_node_1 +// | ... +// +--> DQ -> dst_node_n +// Checks that all insertion edges share the same NodeArg. That is, the edges originate from the same source node +// output. If there is no src_node, then all edges should come from the same graph input. +// This function returns an error status if edges are invalid. 
+// +// Assumes that scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers. +Status InsertQDQPairs(Graph& graph, gsl::span insertion_edges, + NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr, + const std::string& qdq_domain, const NodeAttributes& q_attrs, const NodeAttributes& dq_attrs, + const logging::Logger& logger) { + ORT_RETURN_IF_ERROR(ValidateQDQInsertionEdges(graph, insertion_edges)); + + const ExtendedGraphEdge& first_edge = insertion_edges[0]; // ValidateQDQInsertionEdges() guarantees at least one edge + + Node* src_node = first_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source); // nullptr for graph input + const auto& base_name = first_edge.arg_name; auto& base_node_arg = *graph.GetNodeArg(base_name); - LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between " - << (src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")") - : "input") - << " and " - << (dst_node ? MakeString("node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")") - : "output") - << " at NodeArg \"" << base_name << "\"."; + LogQDQInsertion(logger, logging::Severity::kVERBOSE, ORT_WHERE, graph, insertion_edges); - // set up new NodeArgs - auto& pre_q_nodearg = insertion_edge.HasGraphInputOrInitializer() + auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) { + return zero_point ? InlinedVector{&data, &scale, zero_point} + : InlinedVector{&data, &scale}; + }; + + // Create Q node that will be inserted after src_node + auto& pre_q_nodearg = first_edge.HasGraphInputOrInitializer() ? base_node_arg : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_pre_q"), nullptr); @@ -57,17 +170,6 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, auto& q_to_dq_nodearg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_q_to_dq"), nullptr); - auto& post_dq_nodearg = insertion_edge.HasGraphOutput() - ? base_node_arg - : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_post_dq"), - nullptr); - - // set up new Nodes - auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) { - return zero_point ? 
std::vector{&data, &scale, zero_point} - : std::vector{&data, &scale}; - }; - auto& q_node = graph.AddNode(graph.GenerateNodeName(base_name + "_q"), QDQ::QOpName, "Inserted by QDQPropagationTransformer", @@ -76,40 +178,61 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge, zp_initializer_nodearg_ptr), // outputs {&q_to_dq_nodearg}, - nullptr, // attributes + &q_attrs, // attributes qdq_domain); ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(q_node), "Failed to set op schema for added Q node."); - auto& dq_node = graph.AddNode(graph.GenerateNodeName(base_name + "_dq"), - QDQ::DQOpName, - "Inserted by QDQPropagationTransformer", - // inputs - make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg, - zp_initializer_nodearg_ptr), - // outputs - {&post_dq_nodearg}, - nullptr, // attributes - qdq_domain); - - ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); - - // set up edges - if (src_node && dst_node) { - graph.RemoveEdge(src_node->Index(), dst_node->Index(), - insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx); - } - if (src_node) { - src_node->MutableOutputDefs()[insertion_edge.src->arg_idx] = &pre_q_nodearg; - graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edge.src->arg_idx, 0); - } + // Remove original edges between src and dst nodes. + for (const auto& insertion_edge : insertion_edges) { + auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + + if (dst_node) { + graph.RemoveEdge(src_node->Index(), dst_node->Index(), + insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx); + } + } - graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0); + // Add edge from src to Q node. + src_node->MutableOutputDefs()[first_edge.src->arg_idx] = &pre_q_nodearg; + graph.AddEdge(src_node->Index(), q_node.Index(), first_edge.src->arg_idx, 0); + } - if (dst_node) { - dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg; - graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx); + // Create a DQ node for each dst node and connect remaining edges. + for (size_t edge_idx = 0; edge_idx < insertion_edges.size(); ++edge_idx) { + const auto& insertion_edge = insertion_edges[edge_idx]; + const std::string edge_suffix = edge_idx == 0 ? "" : std::to_string(edge_idx); + auto& post_dq_nodearg = insertion_edge.HasGraphOutput() + ? 
base_node_arg + : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(MakeString(base_name, + "_post_dq", + edge_suffix)), + nullptr); + + auto& dq_node = graph.AddNode(graph.GenerateNodeName(MakeString(base_name, "_dq", edge_suffix)), + QDQ::DQOpName, + "Inserted by QDQPropagationTransformer", + // inputs + make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg, + zp_initializer_nodearg_ptr), + // outputs + {&post_dq_nodearg}, + &dq_attrs, // attributes + qdq_domain); + + ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node."); + + Node* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + + // Add edge from Q to DQ + graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0); + + // Add edge from DQ to dst_node + if (dst_node) { + dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg; + graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx); + } } return Status::OK(); @@ -156,37 +279,39 @@ std::optional GetPreviousPropagationEdge(const Graph& graph, return GetPreviousEdge(graph, *src_node); } -std::optional GetNextEdge(const Graph& graph, const Node& node) { - // for now we can just consider the first output (index 0) +InlinedVector GetNextEdges(const Graph& graph, const Node& node) { + constexpr int node_output_index = 0; // for now we can just consider the first output (index 0) + InlinedVector next_edges; + const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, static_cast(node_output_index)); - const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, 0); - if (output_edges.empty()) { - // maybe edge to output - return ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, 0); + // edges to next nodes + for (const auto& output_edge : output_edges) { + next_edges.push_back(ExtendedGraphEdge::CreateFromValidGraphEdge(output_edge)); } - if (!graph.IsOutput(node.OutputDefs()[0]) && output_edges.size() == 1) { - // single edge to next node - return ExtendedGraphEdge::CreateFromValidGraphEdge(output_edges.front()); + // maybe edge to graph output + auto edge_to_output = ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, node_output_index); + if (edge_to_output.has_value()) { + next_edges.push_back(edge_to_output.value()); } - return std::nullopt; + return next_edges; } -std::optional GetNextPropagationEdge(const Graph& graph, - const ExtendedGraphEdge& edge) { +InlinedVector GetNextPropagationEdges(const Graph& graph, + const ExtendedGraphEdge& edge) { if (edge.HasGraphOutput()) { - return std::nullopt; + return {}; } const auto* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); ORT_ENFORCE(dst_node != nullptr); if (!CanNodePropagate(*dst_node)) { - return std::nullopt; + return {}; } - return GetNextEdge(graph, *dst_node); + return GetNextEdges(graph, *dst_node); } class GraphConstantInitializerGetter { @@ -228,21 +353,54 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices, ? 
dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID] : nullptr; - const auto edge_after_dq = GetNextEdge(graph, dq_node); - if (!edge_after_dq) { + const InlinedVector edges_after_dq = GetNextEdges(graph, dq_node); + if (edges_after_dq.size() != 1) { continue; } - for (auto curr_edge = GetNextPropagationEdge(graph, *edge_after_dq); - curr_edge.has_value(); - curr_edge = GetNextPropagationEdge(graph, *curr_edge)) { - if (const auto* dst_node = curr_edge->GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); - dst_node && QDQ::MatchQNode(*dst_node)) { - break; + // Utility function to check if any edge out of a node (e.g., Transpose) ends in a Q node. + auto any_edge_ends_in_q = [](Graph& graph, const InlinedVector& edges) -> bool { + for (const auto& edge : edges) { + const auto* edge_dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination); + if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) { + return true; + } + } + return false; + }; + + // Propagate DQ forward in a BFS traversal of NodeArg edges. A NodeArg "edge group" consists of one or more edges + // that all begin at the same source node's output slot and end at a graph output or a destination node. + // Ex: The subgraph below shows a NodeArg edge group (containing 3 edges) that begins at a + // Transpose, ends at two destination nodes, and produces a graph output. + // DQ -> Transpose --+--> Sigmoid -> ... + // | + // +--> Slice -> ... + // | + // +--> graph_output + std::queue> node_arg_edges; + node_arg_edges.push(GetNextPropagationEdges(graph, edges_after_dq[0])); + + while (!node_arg_edges.empty()) { + const InlinedVector curr_edge_group = std::move(node_arg_edges.front()); + node_arg_edges.pop(); + + // Skip if edge group is empty. Also, to keep things simple, we do not yet handle edge groups in which + // one of the destination nodes is already a QuantizeLinear node. Ex: + // DQ -> Transpose --+--> QuantizeLinear -> ... + // | + // +--> Slice -> ... 
+ if (curr_edge_group.empty() || any_edge_ends_in_q(graph, curr_edge_group)) { + continue; } - ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, dq_scale, dq_zero_point, dq_node.Domain(), logger)); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, curr_edge_group, dq_scale, dq_zero_point, dq_node.Domain(), + MakeQAttrsFromDQ(dq_node), dq_node.GetAttributes(), logger)); modified = true; + + for (const auto& edge : curr_edge_group) { + node_arg_edges.push(GetNextPropagationEdges(graph, edge)); + } } } @@ -290,7 +448,8 @@ Status PropagateQBackward(Graph& graph, gsl::span node_indices, break; } - ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, q_scale, q_zero_point, q_node.Domain(), logger)); + ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, InlinedVector{*curr_edge}, q_scale, q_zero_point, + q_node.Domain(), q_node.GetAttributes(), MakeDQAttrsFromQ(q_node), logger)); modified = true; } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index a4d1ea3c7cf5..7ef4ced1835f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -166,6 +166,41 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( return true; } +bool IsQOrDQScalePositiveConstantScalar( + const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, + const std::filesystem::path& model_path) { + auto q_or_dq_input_defs = q_or_dq_node.InputDefs(); + + ORT_ENFORCE(q_or_dq_input_defs.size() >= 2); + + if (!optimizer_utils::IsScalar(*q_or_dq_input_defs[InputIndex::SCALE_ID])) { + return false; + } + + const ONNX_NAMESPACE::TensorProto* q_or_dq_scale_tensor_proto = + get_const_initializer(q_or_dq_input_defs[InputIndex::SCALE_ID]->Name()); + if (nullptr == q_or_dq_scale_tensor_proto) { + return false; + } + + Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path); + + switch (q_or_dq_scale.data_type()) { + case ONNX_NAMESPACE::TensorProto::FLOAT: + return q_or_dq_scale.data()[0] > 0; + + case ONNX_NAMESPACE::TensorProto::FLOAT16: + return q_or_dq_scale.data()[0] > 0; + + case ONNX_NAMESPACE::TensorProto::BFLOAT16: + return q_or_dq_scale.data()[0] > 0; + + default: + assert(false); + return false; + } +} + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) bool MatchQNode(const Node& node) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 5d11b8bfd555..008f9972a143 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -65,6 +65,10 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( const GetConstantInitializerFn& get_const_initializer, bool& zero_point_exists); +// Checks that the y_scale/x_scale input to the QuantizeLinear/DequantizeLinear node is a positive scalar. +bool IsQOrDQScalePositiveConstantScalar(const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, + const std::filesystem::path& model_path); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // Check Q node op type, version, and domain. 
bool MatchQNode(const Node& node); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 74fecb0427e1..8f99b7409d4f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" @@ -275,8 +278,10 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } -DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( + int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -286,7 +291,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + p_buffered_tensors_{p_buffered_tensors} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -311,6 +317,7 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { + ORT_RETURN_IF_NOT(p_buffered_tensors_, "Buffered tensors map cannot be null"); const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -338,24 +345,35 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, // to what we need. But it does not handle external data. 
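Note on the repacking below: the transposed weight, scale and zero-point buffers are sized from N, the per-column block count and the packed bytes per block. Assuming, as computed earlier in this function outside the hunk, that quant_num = ceil(K / block_size) and blob_bytes = block_size / 2 for 4-bit weights, the sizes work out as in this illustrative sketch:

#include <cstdint>

// For K = 4096, N = 4096, block_size = 32:
//   quant_num  = 128
//   blob_bytes = 16
//   weight buffer: N * quant_num * blob_bytes      = 8388608 bytes (uint8)
//   scales:        N * quant_num                   = 524288 elements
//   zero points:   N * ((quant_num + 1) / 2)       = 262144 bytes (two 4-bit zero points per byte)
struct MatMulNBitsSizes {
    int64_t weight_bytes;
    int64_t scale_count;
    int64_t zero_point_bytes;
};

inline MatMulNBitsSizes ComputeMatMulNBitsSizes(int64_t K, int64_t N, int64_t block_size) {
    const int64_t quant_num = (K + block_size - 1) / block_size;  // blocks per column
    const int64_t blob_bytes = block_size / 2;                    // 4-bit weights, two per byte
    return {N * quant_num * blob_bytes, N * quant_num, N * ((quant_num + 1) / 2)};
}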
Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::optional zp_src; - Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(weight_arg->Name() + "_T"), - std::vector{N, quant_num, blob_bytes}); - Initializer scale_dst(static_cast(scale_src.data_type()), - graph.GenerateNodeArgName(scale_arg->Name() + "_T"), - std::vector{N * quant_num}); - std::optional zp_dst; + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + std::optional zp_src_ptr; + auto cpu_allocator = std::make_shared(); + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + auto weight_dst_ptr = std::make_unique(uint8_type, + TensorShape{N, quant_num, blob_bytes}, + cpu_allocator); + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); + auto scale_dst_ptr = std::make_unique(scale_type, + TensorShape{scale_size}, + cpu_allocator); + std::string zp_dst_name; + std::unique_ptr zp_dst_ptr; + auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); if (zp_tensor_proto) { - zp_src.emplace(*zp_tensor_proto, graph.ModelPath()); - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{zp_size}, + cpu_allocator); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{zp_size}, + cpu_allocator); + memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes()); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -363,10 +381,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -376,10 +394,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? 
zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -391,10 +409,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -405,10 +423,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -417,28 +435,43 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - ONNX_NAMESPACE::TensorProto weight_T_tp; - ONNX_NAMESPACE::TensorProto scale_T_tp; + auto weight_T_tp = utils::TensorToTensorProto(*weight_dst_ptr, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(*scale_dst_ptr, scale_dst_name, true); std::optional zp_T_tp; - // TODO(fajin): external_data to memory location to avoid arena allocation - // https://github.com/microsoft/onnxruntime/pull/12465 - weight_dst.ToProto(weight_T_tp); - scale_dst.ToProto(scale_T_tp); - if (zp_dst) { - zp_T_tp.emplace(); - zp_dst->ToProto(zp_T_tp.value()); + if (zp_dst_ptr) { + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst_ptr, zp_dst_name, true)); } auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + if (weight_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + // If tensor is too small, tensor proto directly copies data from tensor. The tensor allocated + // here can be directly destructed. + // Only keep the tensor in p_buffered_tensors_ when the tensor proto is using external data location + // and pointing the location to tensor's buffer. 
+ ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(weight_dst_name, std::move(weight_dst_ptr)).second, + "Failed to add buffered tensor ", + weight_dst_name); + } + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + if (scale_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(scale_dst_name, std::move(scale_dst_ptr)).second, + "Failed to add buffered tensor ", + scale_dst_name); + } if (zp_T_tp) { input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); replacement_node.MutableInputArgsCount().push_back(1); + if (zp_T_tp->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(zp_dst_name, std::move(zp_dst_ptr)).second, + "Failed to add buffered tensor ", + zp_dst_name); + } } return Status::OK(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 47821619db65..d25077ca4b49 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -5,10 +5,12 @@ #include #include +#include #include #include "core/optimizer/selectors_actions/actions.h" #include "core/platform/threadpool.h" +#include "core/framework/tensor.h" namespace onnxruntime { @@ -84,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -103,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + std::unordered_map>* p_buffered_tensors_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 17e66a3953b9..379d271fbdca 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include +#include +#include + +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/mlas/inc/mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" @@ -35,6 +38,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 nodes. DQ, target, Q. Merge into target and remove DQ and Q. 
const std::string drop_action_name{"drop"}; const std::string drop_action_no_int16_name{"drop_no_int16_support"}; + const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"}; NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; @@ -46,19 +50,32 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr drop_action_no_int16 = std::make_unique( std::vector(moves)); // Copy before std::move(moves) + std::unique_ptr drop_action_no_int16_and_positive_scale = std::make_unique( + std::vector(moves)); // Copy before std::move(moves) std::unique_ptr drop_action = std::make_unique(std::move(moves)); #if !defined(ORT_MINIMAL_BUILD) - // Use a separate selector + action that disallows 16-bit types for MaxPool and Resize. + // Use separate selectors & actions for MaxPool and Resize. + // + // They disallow 16-bit types for MaxPool and Resize: // int16 MaxPool is not supported by the ONNX specification. // int16 Resize is not supported by the ORT implementation (although allowed by ONNX). - std::unique_ptr selector_disallow_16bit = std::make_unique(false); + // + // And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative + // scale will change the ordering of the elements between quantized & de-quantized values. + std::unique_ptr selector_no_16bit = std::make_unique(false); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name, - {{"MaxPool", {12}}, - {"Resize", {}}}, - std::move(selector_disallow_16bit), + {{"Resize", {}}}, + std::move(selector_no_16bit), std::move(drop_action_no_int16)); + std::unique_ptr selector_no_16bit_and_positive_scale = + std::make_unique(false, true, false); + qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name, + {{"MaxPool", {12}}}, + std::move(selector_no_16bit_and_positive_scale), + std::move(drop_action_no_int16_and_positive_scale)); + std::unique_ptr selector = std::make_unique(true); qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_name, {{"Gather", {}}, @@ -70,6 +87,9 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) { std::move(drop_action)); #else qdq_selector_action_registry.RegisterAction(drop_action_no_int16_name, std::move(drop_action_no_int16)); + qdq_selector_action_registry.RegisterAction( + drop_action_no_int16_and_positive_scale_name, + std::move(drop_action_no_int16_and_positive_scale)); qdq_selector_action_registry.RegisterAction(drop_action_name, std::move(drop_action)); #endif } @@ -230,7 +250,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. 
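Note on the MaxPool-specific rule above: a negative quantization scale reverses the ordering of dequantized values, so MaxPool run directly on the quantized data can pick a different element than MaxPool on the dequantized data, and the surrounding DQ/Q pair cannot be dropped. A small check that makes the failure concrete (illustrative only):

#include <algorithm>
#include <cstdint>

inline float Dequant(uint8_t q, float scale, uint8_t zero_point) {
    return (static_cast<int32_t>(q) - static_cast<int32_t>(zero_point)) * scale;
}

// True when taking the max before or after dequantization gives the same value.
inline bool MaxCommutesWithDequant(uint8_t a, uint8_t b, float scale, uint8_t zero_point) {
    const float max_then_dequant = Dequant(std::max(a, b), scale, zero_point);
    const float dequant_then_max = std::max(Dequant(a, scale, zero_point),
                                            Dequant(b, scale, zero_point));
    return max_then_dequant == dequant_then_max;
}

// MaxCommutesWithDequant(10, 20, 0.5f, 0) is true, but with scale = -0.5f the
// quantized max (20) dequantizes to -10 while the real max is -5, so the call
// returns false; that is exactly the case the positive-scale check rules out.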
@@ -238,7 +259,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr selector = std::make_unique(); @@ -295,9 +317,11 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { +SelectorActionRegistry CreateSelectorActionRegistry( + bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -311,20 +335,24 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, - const SatApplyContextVariant& apply_context, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +QDQSelectorActionTransformer::QDQSelectorActionTransformer( + bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, + intra_op_thread_pool, p_buffered_tensors), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index ba636f76d190..627ddd35b991 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" @@ -25,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 6e93445c7c5c..203aba2c3dd9 100644 --- 
a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -150,6 +150,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return graph_viewer.GetConstantInitializer(initializer_name, true); }; + if (!allow_nonpositive_scale_) { + // IsQDQPairSupported will check that the scale is the same between q_node and dq_node. + if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } @@ -632,7 +639,7 @@ bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 3)) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 491a15b62cb0..7e009da39403 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -48,8 +48,9 @@ class NodeGroupSelector { // Zero point and scale are constant scalars and must match class DropQDQNodeGroupSelector : public NodeGroupSelector { public: - explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true) - : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} + explicit DropQDQNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true, + bool allow_nonpositive_scale = true) + : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {} private: bool Check(const GraphViewer& graph_viewer, const Node& node, @@ -58,6 +59,7 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; bool allow_4bit_; + bool allow_nonpositive_scale_; }; // Single DQ -> node. 
@@ -300,8 +302,8 @@ class BaseSelector : public NodeSelector { class DropQDQNodesSelector : public BaseSelector { public: - explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false) - : BaseSelector(std::make_unique(allow_16bit, allow_4bit)) {} + explicit DropQDQNodesSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true) + : BaseSelector(std::make_unique(allow_16bit, allow_4bit, allow_nonpositive_scale)) {} }; class DropDQNodesSelector : public BaseSelector { diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index cf70a7d821d7..655364357999 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -168,7 +168,8 @@ Note: This fusion doesn't consider the following case: LayerNormalization */ -Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { +Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); InlinedVector> nodes_to_remove; @@ -299,12 +300,15 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le // Assign provider to this new node. Provider should be same as the provider for old node. skip_layer_norm_node.SetExecutionProviderType(ln_node.GetExecutionProviderType()); } + for (const auto& node : nodes_to_remove) { graph_utils::RemoveNodeOutputEdges(graph, node); graph.RemoveNode(node.get().Index()); } - modified = true; + if (!nodes_to_remove.empty()) { + modified = true; + } return Status::OK(); } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index b0f9eaf4f62d..ef42c88a67ba 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -137,28 +137,45 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( EtwRegistrationManager::~EtwRegistrationManager() { std::lock_guard lock(callbacks_mutex_); callbacks_.clear(); - ::TraceLoggingUnregister(etw_provider_handle); + if (initialization_status_ == InitializationStatus::Initialized || + initialization_status_ == InitializationStatus::Initializing) { + std::lock_guard init_lock(init_mutex_); + assert(initialization_status_ != InitializationStatus::Initializing); + if (initialization_status_ == InitializationStatus::Initialized) { + ::TraceLoggingUnregister(etw_provider_handle); + initialization_status_ = InitializationStatus::NotInitialized; + } + } } EtwRegistrationManager::EtwRegistrationManager() { } -void EtwRegistrationManager::LazyInitialize() { - if (!initialized_) { +void EtwRegistrationManager::LazyInitialize() try { + if (initialization_status_ == InitializationStatus::NotInitialized) { std::lock_guard lock(init_mutex_); - if (!initialized_) { // Double-check locking pattern - initialized_ = true; + if (initialization_status_ == InitializationStatus::NotInitialized) { // Double-check locking pattern + initialization_status_ = InitializationStatus::Initializing; etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); if (FAILED(etw_status_)) { ORT_THROW("ETW registration failed. 
Logging will be broken: " + std::to_string(etw_status_)); } + initialization_status_ = InitializationStatus::Initialized; } } +} catch (...) { + initialization_status_ = InitializationStatus::Failed; + throw; } void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { + if (initialization_status_ != InitializationStatus::Initialized) { + // Drop messages until manager is fully initialized. + return; + } + std::lock_guard lock(callbacks_mutex_); for (const auto& callback : callbacks_) { (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 3af45b813a62..d6c9ea27b295 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -47,6 +47,11 @@ class EtwSink : public ISink { }; class EtwRegistrationManager { + enum class InitializationStatus { NotInitialized, + Initializing, + Initialized, + Failed }; + public: using EtwInternalCallback = std::function dkernel = SafeInt(dilation) * (kernel - 1) + 1; + int64_t dkernel_value = SafeInt(in_dim) + pad_head + pad_tail - dkernel; + return static_cast(static_cast(dkernel_value) / stride + 1); +} + inline Status ComputePad(const int64_t in_dim, const int64_t stride, const int64_t kernel, const int64_t dilation, AutoPadType pad_type, @@ -106,6 +114,15 @@ inline Status ComputePad(const int64_t in_dim, // is retained as is SafeInt legacy_target_size = (SafeInt(in_dim) + stride - 1) / stride; SafeInt pad_needed = (legacy_target_size - 1) * stride + kernel - in_dim; + // out_dim = floor((in_dim + 2p - k) / s) + 1 + // => if (in_dim + 2p - k) is not divisible by s we can remove the floor with following equation: + // out_dim + eps = ((in_dim + 2p - k) / s) + 1 ;where eps is in [0.0, 1.0] + // therefore in edge cases padding can lower calculated above than it should be + SafeInt actual_out_size = ComputeOutputShape(in_dim, stride, kernel, /*dilation*/ 1, + pad_needed, pad_needed); + if (actual_out_size < legacy_target_size) { + pad_needed += 1; + } // make sure padding is symmetric if (force_symmetric_auto_padding) { // Inlining math::roundUpPow2() from util/math.h to avoid bringing in the transitive dependencies. 
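Reviewer note: the SAME-padding adjustment above can be sketched in Python as follows. This is an assumption-laden illustration rather than the ORT implementation: output_size and same_pad are hypothetical helpers, the check mirrors how the new code passes pad_needed for both head and tail when re-evaluating the output size, and the last test case is one where the floor-based output size comes out short so the extra unit of padding is needed.

```python
import math

# Sketch of the padding fix: target output size is ceil(in_dim / stride); after
# computing pad_needed, re-check the actual output size with the floor formula and
# add one more unit of padding if the floor made it come out short.
def output_size(in_dim, stride, kernel, pad_head, pad_tail, dilation=1):
    dkernel = dilation * (kernel - 1) + 1
    return (in_dim + pad_head + pad_tail - dkernel) // stride + 1

def same_pad(in_dim, stride, kernel):
    target = math.ceil(in_dim / stride)
    pad = (target - 1) * stride + kernel - in_dim
    if output_size(in_dim, stride, kernel, pad, pad) < target:
        pad += 1  # edge case: the floor division dropped a fraction
    return pad

for in_dim, stride, kernel in [(10, 3, 2), (7, 2, 3), (6, 4, 1)]:
    pad = same_pad(in_dim, stride, kernel)
    assert output_size(in_dim, stride, kernel, pad, pad) >= math.ceil(in_dim / stride)
    print((in_dim, stride, kernel), "->", pad)
```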
@@ -126,14 +143,6 @@ inline Status ComputePad(const int64_t in_dim, return Status::OK(); } -constexpr inline int64_t ComputeOutputShape(const int64_t in_dim, - const int64_t stride, const int64_t kernel, const int64_t dilation, - const int64_t pad_head, const int64_t pad_tail) { - const SafeInt dkernel = SafeInt(dilation) * (kernel - 1) + 1; - int64_t dkernel_value = SafeInt(in_dim) + pad_head + pad_tail - dkernel; - return static_cast(static_cast(dkernel_value) / stride + 1); -} - inline Status ComputePadAndOutputShape(const int64_t in_dim, const int64_t stride, const int64_t kernel, const int64_t dilation, AutoPadType pad_type, diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 0e2171551370..c8670cd54625 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -83,12 +83,16 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, using namespace CoreML::Specification::MILSpec; // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation std::string_view coreml_op_type; + bool add_alpha = false; if (op_type == "Sigmoid") { coreml_op_type = "sigmoid"; } else if (op_type == "Tanh") { coreml_op_type = "tanh"; } else if (op_type == "Relu") { coreml_op_type = "relu"; + } else if (op_type == "LeakyRelu") { + coreml_op_type = "leaky_relu"; + add_alpha = true; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -96,6 +100,13 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + + if (add_alpha) { + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.01f); + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } + AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); @@ -198,7 +209,7 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp #if defined(COREML_ENABLE_MLPROGRAM) if (input_params.create_mlprogram) { - if (op_type == "PRelu" || op_type == "LeakyRelu") { + if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable return false; } } else diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index ebb3f97895f0..e02186d3aee8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -309,11 +309,33 @@ COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& n void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) { MILSpec::Argument arg; - arg.mutable_arguments()->Add()->set_name(std::string(value_name)); + arg.mutable_arguments()->Add()->set_name(value_name.data(), value_name.size()); (*op.mutable_inputs())[input_name] = std::move(arg); } +void AddOperationVariadicInput(MILSpec::Operation& op, std::string_view input_name, + const std::vector& value_names) { + MILSpec::Argument arg; + for (const auto& 
value : value_names) { + arg.mutable_arguments()->Add()->set_name(value.data(), value.size()); + } + + (*op.mutable_inputs())[input_name] = std::move(arg); +} + +void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::string_view output_name, + int32_t element_type, std::optional> shape) { + auto& outputs = *op.mutable_outputs(); + auto& output_arg = *outputs.Add(); + output_arg.set_name(output_name.data(), output_name.size()); + + MILSpec::ValueType& value = *output_arg.mutable_type(); + MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(element_type), shape, /*convert_scalar*/ true); +} + void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output, std::optional override_element_type) { auto& outputs = *op.mutable_outputs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index f012e6af0d71..475ce79b0a81 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -129,6 +129,26 @@ COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& n void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name, std::string_view value_name); +/// +/// Add a variadic input argument to a MILSpec::Operation +/// +/// Operation to update. +/// The input name defined by the spec for the operation. +/// The input value names. +void AddOperationVariadicInput(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name, + const std::vector& value_names); + +/// Add an output to a MILSpec::Operation for an intermediate operation when the implementation is composed of +/// multiple MLProgram operations. In this case we don't have a NodeArg for the output. +/// +/// Operation to update. +/// Name of the intermediate output. Create using ModelBuilder::GetUniqueName. +/// onnx::TensorProto_DataType element type of the output. +/// int32_t as that is what TensorShapeProto uses to store the value. +/// Shape of the output if known. +void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::string_view output_name, + int32_t element_type, std::optional> shape); + /// /// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg. 
/// diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index 34193318a026..9ea0030290ab 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,27 +19,51 @@ class ConcatOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - - layer->mutable_concat()->set_sequenceconcat(false); - - for (const auto* input : node.InputDefs()) { - LOGS(logger, VERBOSE) << "input name " << input->Name(); - *layer->mutable_input()->Add() = input->Name(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; // NOLINT + + NodeAttrHelper helper(node); + const auto axis = helper.GetInt64("axis"); // required + const auto interleave = false; + + std::unique_ptr op = model_builder.CreateOperation(node, "concat"); + std::vector input_names; + for (const auto* input : node.InputDefs()) { + input_names.emplace_back(input->Name()); + } + AddOperationVariadicInput(*op, "values", input_names); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", *axis)); + AddOperationInput(*op, "interleave", model_builder.AddScalarConstant(op->type(), "interleave", interleave)); + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + layer->mutable_concat()->set_sequenceconcat(false); + + for (const auto* input : node.InputDefs()) { + LOGS(logger, VERBOSE) << "input name " << input->Name(); + *layer->mutable_input()->Add() = input->Name(); + } + + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); } - - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 2) { @@ -50,23 +75,25 @@ bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa if (!GetShape(*input_defs[0], input_shape, logger)) return false; - auto rank = input_shape.size(); - if (rank != 4) { - // For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis - // Instead of concat on axis 0, it will concat on axis 1 - // Disable Concat support for 3d tensor 
for now - // TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d - LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is " - << rank << "d shape"; - return false; - } - - NodeAttrHelper helper(node); - auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); - if (rank != axis + 3) { - LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis - << ", actual rank: " << rank; - return false; + if (!input_params.create_mlprogram) { + auto rank = input_shape.size(); + if (rank != 4) { + // For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis + // Instead of concat on axis 0, it will concat on axis 1 + // Disable Concat support for 3d tensor for now + // TODO: add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d + LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is " + << rank << "d shape"; + return false; + } + + NodeAttrHelper helper(node); + auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + if (rank != axis + 3) { + LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis + << ", actual rank: " << rank; + return false; + } } return true; diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index 1eba312b2577..bec2461ffbc5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -4,6 +4,7 @@ #include "core/common/safeint.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,52 +19,133 @@ class DepthToSpaceOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - + [[maybe_unused]] const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); - const auto& output_name = output_defs[0]->Name(); - uint64_t blocksize = SafeInt(node.GetAttributes().at("blocksize").i()); + NodeAttrHelper helper(node); + int64_t blocksize = *helper.GetInt64("blocksize"); // required attribute + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; // NOLINT + + const auto mode = helper.Get("mode", "DCR"); + + if (mode == "DCR") { + // DCR is directly supported + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.depth_to_space + // Validated with depth_to_space.py. 
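Reviewer note: for reference, here is a small numpy sketch (plain reference code, not ORT or CoreML APIs) of the two ONNX DepthToSpace modes discussed here. DCR maps directly onto CoreML's depth_to_space op; CRD is what the CRD branch below decomposes into reshape -> transpose -> reshape, and the 5D variant at the end mirrors the builder's trick of merging the batch dim into the channel dim to stay within CoreML's 5D limit.

```python
import numpy as np

# Reference implementations of the ONNX DepthToSpace modes
# (per https://github.com/onnx/onnx/blob/main/docs/Operators.md#depthtospace).
def depth_to_space_dcr(x, bs):
    b, c, h, w = x.shape
    t = x.reshape(b, bs, bs, c // (bs * bs), h, w).transpose(0, 3, 4, 1, 5, 2)
    return t.reshape(b, c // (bs * bs), h * bs, w * bs)

def depth_to_space_crd(x, bs):
    b, c, h, w = x.shape
    t = x.reshape(b, c // (bs * bs), bs, bs, h, w).transpose(0, 1, 4, 2, 5, 3)
    return t.reshape(b, c // (bs * bs), h * bs, w * bs)

def depth_to_space_crd_5d(x, bs):
    # 5D variant from the builder comments: merge batch into channel, then
    # reshape -> transpose([0, 3, 1, 4, 2]) -> reshape.
    b, c, h, w = x.shape
    t = x.reshape(b * c // (bs * bs), bs, bs, h, w).transpose(0, 3, 1, 4, 2)
    return t.reshape(b, c // (bs * bs), h * bs, w * bs)

x = np.arange(1 * 8 * 2 * 3, dtype=np.float32).reshape(1, 8, 2, 3)
print(depth_to_space_dcr(x, 2).shape)                                   # (1, 2, 4, 6)
print(np.array_equal(depth_to_space_dcr(x, 2), depth_to_space_crd(x, 2)))  # False: modes order channels differently
assert np.array_equal(depth_to_space_crd(x, 2), depth_to_space_crd_5d(x, 2))
```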
+ auto op = model_builder.CreateOperation(node, "depth_to_space"); + AddOperationInput(*op, "x", input_name); + AddOperationInput(*op, "block_size", model_builder.AddScalarConstant(op->type(), "blocksize", blocksize)); + AddOperationOutput(*op, *output_defs[0]); + model_builder.AddOperation(std::move(op)); + } else { + // CRD is manual. there may be a perf cost from the Reshape's (typically that happens on CPU) but if the input + // is a fixed size hopefully CoreML is smart enough to handle that aspect during model compilation instead + // of execution. + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#depthtospace + // b, c, h, w = x.shape + // tmp = np.reshape(x, [b, c // (blocksize ** 2), blocksize, blocksize, h, w]) + // tmp = np.transpose(tmp, [0, 1, 4, 2, 5, 3]) + // y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize]) + // + // CoreML has a 5D limit, so we merge the batch dim into the channel dim as that doesn't change the data + // movement. + // First reshape is to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] + // Transpose is to [0, 3, 1, 4, 2] + + // we checked shape was static in IsOpSupportedImpl so this should never fail + std::vector input_shape; + ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Failed to get input shape"); + const int32_t elem_type = static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + // reshape to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + std::vector shape1 = {input_shape[0] * input_shape[1] / (blocksize * blocksize), + blocksize, blocksize, input_shape[2], input_shape[3]}; + AddOperationInput(*reshape1, "x", input_name); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape", shape1)); + const auto& reshape1_output = model_builder.GetUniqueName(node, "reshape1"); + AddIntermediateOperationOutput(*reshape1, reshape1_output, elem_type, shape1); + + // transpose to [0, 3, 1, 4, 2] + auto transpose = model_builder.CreateOperation(node, "transpose"); + std::vector perm = {0, 3, 1, 4, 2}; + std::vector shape2 = {shape1[0], shape1[3], shape1[1], shape1[4], shape1[2]}; + AddOperationInput(*transpose, "x", reshape1_output); + AddOperationInput(*transpose, "perm", model_builder.AddConstant(transpose->type(), "perm", perm)); + const auto& transpose_output = model_builder.GetUniqueName(node, "transpose"); + AddIntermediateOperationOutput(*transpose, transpose_output, elem_type, shape2); + + // reshape to [b, c // (blocksize ** 2), h * blocksize, w * blocksize] + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + std::vector shape3 = {input_shape[0], + input_shape[1] / (blocksize * blocksize), + input_shape[2] * blocksize, + input_shape[3] * blocksize}; + AddOperationInput(*reshape2, "x", transpose_output); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape", shape3)); + + AddOperationOutput(*reshape2, *output_defs[0]); + + model_builder.AddOperation(std::move(reshape1)); + model_builder.AddOperation(std::move(transpose)); + model_builder.AddOperation(std::move(reshape2)); + } + } else // NOLINT +#endif // if defined(COREML_ENABLE_MLPROGRAM) + { + const auto& output_name = output_defs[0]->Name(); + std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_depthtospace = layer->mutable_reorganizedata(); - coreml_depthtospace->set_blocksize(blocksize); - 
coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType:: - ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE); + auto* coreml_depthtospace = layer->mutable_reorganizedata(); + coreml_depthtospace->set_blocksize(static_cast(blocksize)); + coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType:: + ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "DepthToSpace: no input shape"; return false; } - const auto input_rank = input_shape.size(); - if (input_rank < 4) { - LOGS(logger, VERBOSE) << "DepthToSpace does not support input shape of " << input_rank << "d shape."; - } + // ONNX and CoreML both require 4D input so no need to check the shape here. NodeAttrHelper helper(node); - if (node.SinceVersion() >= 11) { - // For now, only DCR mode DepthToSpace is supported - const auto mode = helper.Get("mode", "DCR"); + const auto mode = helper.Get("mode", "DCR"); + + if (input_params.create_mlprogram) { + if (mode == "CRD" && !IsStaticShape(input_shape)) { + // we need to manually implement the logic with a Reshape, so we need to know the shape to do that + LOGS(logger, VERBOSE) << "DepthToSpace: CRD mode requires static shape"; + return false; + } + } else { if (mode != "DCR") { - LOGS(logger, VERBOSE) << "The mode: " << mode << "of DepthToSpace is not supported in CoreML EP for now."; + LOGS(logger, VERBOSE) << "DepthToSpace: " << mode << " mode is not supported"; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc index bfc665e0ac71..9caec290ea5a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc @@ -19,8 +19,8 @@ std::string_view GetMode(const NodeAttrHelper& helper) { // opset 20+ uses linear, nearest, cubic // bilinear is what CoreML uses, so prefer that // bicubic/cubic isn't supported - - const auto& mode = helper.Get("mode", "linear"); + static const std::string default_mode = "linear"; // static in case we ever return the default as a string_view + const auto& mode = helper.Get("mode", default_mode); if (mode == "linear") { return "bilinear"; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 0497357c45c5..dbd0f48576f8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -5,6 +5,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" 
+#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -24,6 +25,8 @@ class SplitOpBuilder : public BaseOpBuilder { // Split opset 13- uses "split" as attribute. Currently it's not supported. int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } + + bool SupportsMLProgram() const override { return true; } }; void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -43,55 +46,98 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); NodeAttrHelper helper(node); - const auto axis = helper.Get("axis", 0); + int64_t axis = helper.Get("axis", 0); - // attribute introduced since opset 18 - uint64_t num_outputs; - - std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_splitnd = layer->mutable_splitnd(); - coreml_splitnd->set_axis(axis); - - if (input_defs.size() > 1) { - // if "split" is explicitly provided as an input - const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - Initializer unpacked_tensor(split_tensor); - auto split_span = unpacked_tensor.DataAsSpan(); - auto split_sizes = split_span.size(); - num_outputs = narrow(split_sizes); - for (size_t i = 0; i < split_sizes; i++) { - coreml_splitnd->add_splitsizes(split_span[i]); - } - } else if (node.SinceVersion() < 18) { - num_outputs = narrow(node.OutputDefs().size()); - coreml_splitnd->set_numsplits(num_outputs); - } else { - // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt64("num_outputs").value()); + auto calculate_remainder_and_chunk_size = [&](int32_t num_outputs) { // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); + uint64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs; uint64_t remainder = split_dim_size % chunk_size; - if (remainder) { - // uneven - auto split_sizes = InlinedVector(num_outputs, chunk_size); - split_sizes.back() = remainder; - for (size_t i = 0; i < split_sizes.size(); i++) { - coreml_splitnd->add_splitsizes(split_sizes[i]); - } + return std::make_tuple(remainder, chunk_size); + }; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::unique_ptr split_op = model_builder.CreateOperation(node, "split"); + AddOperationInput(*split_op, "axis", model_builder.AddScalarConstant(split_op->type(), "axis", axis)); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + auto split_span = unpacked_tensor.DataAsSpan(); + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", split_span)); + } else if (node.SinceVersion() < 18) { + int64_t num_outputs = narrow(node.OutputDefs().size()); + AddOperationInput(*split_op, "num_splits", + model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs)); } else { - // even + // note: for opset 18+ 'num_outputs' is a required attribute 
+ int64_t num_outputs = helper.GetInt64("num_outputs").value(); + auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); + if (remainder) { + // uneven + std::vector split_sizes(num_outputs, chunk_size); + split_sizes.back() = remainder; + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", split_sizes)); + } else { + // even + AddOperationInput(*split_op, "num_splits", + model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs)); + } + } + + AddOperationInput(*split_op, "x", input_defs[0]->Name()); + for (const auto& output_def : node.OutputDefs()) { + AddOperationOutput(*split_op, *output_def); + } + model_builder.AddOperation(std::move(split_op)); + + } else +#endif + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + auto* coreml_splitnd = layer->mutable_splitnd(); + coreml_splitnd->set_axis(axis); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + auto split_span = unpacked_tensor.DataAsSpan(); + for (const auto& split_size : split_span) { + coreml_splitnd->add_splitsizes(split_size); + } + } else if (node.SinceVersion() < 18) { + uint64_t num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); + } else { + // note: for opset 18+ 'num_outputs' is a required attribute + uint64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); + auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); + if (remainder) { + // uneven + auto split_sizes = InlinedVector(num_outputs, chunk_size); + split_sizes.back() = remainder; + for (size_t i = 0; i < split_sizes.size(); i++) { + coreml_splitnd->add_splitsizes(split_sizes[i]); + } + } else { + // even + coreml_splitnd->set_numsplits(num_outputs); + } } - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. - // Otherwise, uses attribute value 'num_outputs'. - for (uint64_t i = 0; i < num_outputs; i++) { - *layer->mutable_output()->Add() = node.OutputDefs()[i]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + // variadic number of outputs. Calculated based on the length of the given splitSizes if provided. + // Otherwise, uses attribute value 'num_outputs'. 
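Reviewer note: the uneven-split handling above can be summarized with a short Python sketch (illustrative only; split_sizes is a hypothetical helper): when only num_outputs is known, the builder uses ceil-sized chunks and puts the remainder in the last output, and falls back to num_splits when the split is even.

```python
# Sketch of how per-output split sizes are derived from 'num_outputs' (opset 18+).
def split_sizes(split_dim_size, num_outputs):
    chunk = (split_dim_size + num_outputs - 1) // num_outputs  # ceil division
    remainder = split_dim_size % chunk
    if remainder:
        # uneven: ceil-sized chunks, remainder in the last output
        return [chunk] * (num_outputs - 1) + [remainder]
    return None  # even: just pass num_splits

print(split_sizes(10, 3))  # [4, 4, 2]
print(split_sizes(9, 3))   # None -> even split into chunks of 3
```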
+ for (const auto& output_def : node.OutputDefs()) { + *layer->mutable_output()->Add() = output_def->Name(); + } + model_builder.AddLayer(std::move(layer)); } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -99,7 +145,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); NodeAttrHelper helper(node); const auto axis = helper.Get("axis", 0); @@ -110,16 +155,19 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())]; if (input_defs.size() > 1 && input_defs[1]->Exists()) { - if (!CheckIsConstantInitializer(*input_defs[1], input_params.graph_viewer, logger, "'split'")) { + const auto* splits_tensor = input_params.graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!splits_tensor) { + LOGS(logger, VERBOSE) << "CoreML 'splits' input must be a constant initializer."; return false; } + const auto split_shape = *input_defs[1]->Shape(); if (split_shape.dim_size() < 2) { - LOGS(logger, VERBOSE) << "CoreML SplitND requires to produce at least 2 outputs."; + LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; return false; } - const auto& splits_tensor = *initializers.at(input_defs[1]->Name()); - Initializer unpacked_tensor(splits_tensor); + + Initializer unpacked_tensor(*splits_tensor); auto splits_span = unpacked_tensor.DataAsSpan(); int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), int64_t{0}); if (sum_of_splits != split_dims_at_axis) { diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index 535712f09601..b0006b24e7d7 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -15,28 +15,28 @@ namespace coreml { static OpBuilderRegistrations CreateOpBuilderRegistrations() { OpBuilderRegistrations op_registrations; + // Activations + CreateActivationOpBuilder("Sigmoid", op_registrations); + CreateActivationOpBuilder("Tanh", op_registrations); + CreateActivationOpBuilder("Relu", op_registrations); + CreateActivationOpBuilder("PRelu", op_registrations); + CreateActivationOpBuilder("LeakyRelu", op_registrations); + // Unary ops - CreateUnaryOpBuilder("Sqrt", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Sqrt", op_registrations); // Binary elementwise ops CreateBinaryOpBuilder("Add", op_registrations); + CreateBinaryOpBuilder("Div", op_registrations); CreateBinaryOpBuilder("Mul", op_registrations); CreateBinaryOpBuilder("Pow", op_registrations); CreateBinaryOpBuilder("Sub", op_registrations); - CreateBinaryOpBuilder("Div", op_registrations); - - // Activations - CreateActivationOpBuilder("Sigmoid", op_registrations); - CreateActivationOpBuilder("Tanh", op_registrations); - CreateActivationOpBuilder("Relu", op_registrations); - CreateActivationOpBuilder("PRelu", op_registrations); - CreateActivationOpBuilder("LeakyRelu", op_registrations); // Pooling ops + CreatePoolOpBuilder("AveragePool", op_registrations); CreatePoolOpBuilder("GlobalAveragePool", op_registrations); 
CreatePoolOpBuilder("GlobalMaxPool", op_registrations); - CreatePoolOpBuilder("AveragePool", op_registrations); CreatePoolOpBuilder("MaxPool", op_registrations); // Reduction ops @@ -54,6 +54,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateFlattenOpBuilder("Flatten", op_registrations); CreateGatherOpBuilder("Gather", op_registrations); CreateGemmOpBuilder("Gemm", op_registrations); + CreateGridSampleOpBuilder("GridSample", op_registrations); CreateLRNOpBuilder("LRN", op_registrations); CreateGemmOpBuilder("MatMul", op_registrations); CreatePadOpBuilder("Pad", op_registrations); @@ -66,8 +67,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSqueezeOpBuilder("Squeeze", op_registrations); CreateTransposeOpBuilder("Transpose", op_registrations); - CreateGridSampleOpBuilder("GridSample", op_registrations); - return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index a92fef81ac39..f2cd4d01174d 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -83,7 +83,9 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, - gen_metadef_name, COREML, kCoreMLExecutionProvider); + gen_metadef_name, COREML, kCoreMLExecutionProvider, + nullptr, + /*drop_constant_initializers*/ true); const auto num_of_partitions = result.size(); const auto num_of_supported_nodes = std::transform_reduce( diff --git a/onnxruntime/core/providers/coreml/DebugMLProgram.md b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/DebugMLProgram.md similarity index 97% rename from onnxruntime/core/providers/coreml/DebugMLProgram.md rename to onnxruntime/core/providers/coreml/mlprogram_test_scripts/DebugMLProgram.md index e41a51559430..b7a54466ab8d 100644 --- a/onnxruntime/core/providers/coreml/DebugMLProgram.md +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/DebugMLProgram.md @@ -25,6 +25,8 @@ https://apple.github.io/coremltools/docs-guides/source/model-intermediate-langua Usage is reasonably intuitive. The below example defines a model with 2 inputs and a matmul operator. The model is printed, and run with randomly generated inputs. The output from doing so is printed. +There are additional test scripts in this directory for different operators. 
+ ```python import numpy as np import coremltools as ct diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py new file mode 100644 index 000000000000..430a2b3fa3ed --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py @@ -0,0 +1,33 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +a_shape = (1, 1, 3, 3) + + +@mb.program( + input_specs=[mb.TensorSpec(shape=a_shape), mb.TensorSpec(shape=a_shape), mb.TensorSpec(shape=a_shape)], + opset_version=target, +) +def prog(x, y, z): + axis = mb.const(val=1) + interleave = mb.const(val=False) + z = mb.concat(values=(x, y, z), axis=axis, interleave=interleave) + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +x = np.random.rand(*a_shape) +y = np.random.rand(*a_shape) +z = np.random.rand(*a_shape) + +# spec = m.get_spec() +# print(spec) + +print(m.predict({"x": x, "y": y, "z": z})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py new file mode 100644 index 000000000000..2c8cbc4948a6 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py @@ -0,0 +1,42 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (1, 3, 4, 4) +w_shape = (3, 3, 3, 3) + + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target) +def prog(x): + weight = mb.const(name="weight", val=np.ones(w_shape, dtype=np.float32)) + output_shape = mb.const(name="output_shape", val=np.array([1, 3, 4, 4])) + # pad = mb.const(val=np.zeros((4), dtype=np.int32)) + strides = mb.const(name="strides", val=np.ones((2), dtype=np.int32)) + dilations = mb.const(name="dilations", val=np.ones((2), dtype=np.int32)) + z = mb.conv_transpose( + x=x, weight=weight, strides=strides, dilations=dilations, output_shape=output_shape + ) # , pad=pad + + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +# spec = m.get_spec() +# print(spec) + +m.save("ConvTranspose.mlpackage") +# construct MLModel with compute_units=ComputeUnit.CPU and run predict +m_cpu = ct.models.MLModel("ConvTranspose.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY) +m_all = ct.models.MLModel("ConvTranspose.mlpackage", compute_units=ct.ComputeUnit.ALL) + +x = np.ones(x_shape, dtype=np.float32) +print("CPU_ONLY") +print(m_cpu.predict({"x": x})) +print("ALL") +print(m_all.predict({"x": x})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py new file mode 100644 index 000000000000..593d9e8bbf66 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py @@ -0,0 +1,51 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +# replicate example from https://github.com/onnx/onnx/blob/main/docs/Operators.md#depthtospace +# to prove CoreML mode is DCR +x_shape = (1, 8, 2, 3) + + 
+@mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target) +def prog(x): + block_size = mb.const(name="block_size", val=2) + z = mb.depth_to_space(x=x, block_size=block_size) + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +# spec = m.get_spec() +# print(spec) + +m.save("DepthToSpace.mlpackage") + +# also check for differences between CPU_ONLY and ALL +m_cpu = ct.models.MLModel("DepthToSpace.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY) +m_all = ct.models.MLModel("DepthToSpace.mlpackage", compute_units=ct.ComputeUnit.ALL) + +x = np.array( + [ + [ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[9.0, 10.0, 11.0], [12.0, 13.0, 14.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + [[27.0, 28.0, 29.0], [30.0, 31.0, 32.0]], + [[36.0, 37.0, 38.0], [39.0, 40.0, 41.0]], + [[45.0, 46.0, 47.0], [48.0, 49.0, 50.0]], + [[54.0, 55.0, 56.0], [57.0, 58.0, 59.0]], + [[63.0, 64.0, 65.0], [66.0, 67.0, 68.0]], + ] + ] +).astype(np.float32) + +print("CPU_ONLY") +print(m_cpu.predict({"x": x})) +print("ALL") +print(m_all.predict({"x": x})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py new file mode 100644 index 000000000000..a0423511598f --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py @@ -0,0 +1,103 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb +from coremltools.models import datatypes +from coremltools.models.neural_network import NeuralNetworkBuilder +from coremltools.models.utils import save_spec + +input_dim = (1,) +output_dim = (1,) + + +def mlprogram(): + target = ct.target.iOS15 + + @mb.program(input_specs=[mb.TensorSpec(shape=input_dim), mb.TensorSpec(shape=input_dim)], opset_version=target) + def prog(x, y): + return mb.real_div(x=x, y=y) + + # print(prog) + + # Convert to ML program + m = ct.convert(prog, minimum_deployment_target=target) + + x = np.array([2], dtype=np.float32) + y = np.array([2047], dtype=np.float32) + + # spec = m.get_spec() + # print(spec) + + print(m.predict({"x": x, "y": y})) + + +# implement Div with coremltools approach of x * (1/y) +def nn(): + input_features = [("x", datatypes.Array(*input_dim)), ("y_inv", datatypes.Array(*input_dim))] + output_features = [("final", datatypes.Array(*output_dim))] + + # Build a simple neural network with 1 inner product layer + builder = NeuralNetworkBuilder(input_features, output_features) + builder.add_elementwise( + name="x_multiply_inverse_of_y", + input_names=["x", "y_inv"], + output_name="final", + mode="MULTIPLY", + ) + + save_spec(builder.spec, "network.mlmodel") + m = ct.models.MLModel("network.mlmodel") + + x = np.array([2], dtype=np.float32) + y = np.array([1 / 2047], dtype=np.float32) + print(m.predict({"x": x, "y_inv": y})) + + +def nn_scale(): + input_features = [ + ("x", datatypes.Array(*input_dim)), + ("y_inv", datatypes.Array(*input_dim)), + ("z", datatypes.Array(*input_dim)), + ] + output_features = [("final", datatypes.Array(*output_dim))] + + builder = NeuralNetworkBuilder(input_features, output_features) + + builder.add_elementwise( + name="div_implemented_as_x_multiply_inverse_of_y", + input_names=["x", "y_inv"], + output_name="div_result", + mode="MULTIPLY", + ) + + builder.add_elementwise( + name="apply_scaling_factor", + input_names=["div_result", "z"], + output_name="final", + mode="MULTIPLY", + 
) + + from coremltools.models.utils import save_spec + + save_spec(builder.spec, "network.mlmodel") + m = ct.models.MLModel("network.mlmodel") + + a = 2 + b = 2047 + # scaling factor to test working around coremltools inaccuracy. + # weirdly even a scaling factor of 1 fixes the problem from https://github.com/microsoft/onnxruntime/issues/21170 + c = 1000 + + x = np.array([a], dtype=np.float32) + y = np.array([1 / b / c], dtype=np.float32) + z = np.array([c], dtype=np.float32) + print(m.predict({"x": x, "y_inv": y, "z": z})) + + +print("NN") +nn() + +print("\nNN with scaling") +nn_scale() + +print("\nML Program") +mlprogram() diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/dump_mlprogram_model.py similarity index 100% rename from onnxruntime/core/providers/coreml/dump_mlprogram_model.py rename to onnxruntime/core/providers/coreml/mlprogram_test_scripts/dump_mlprogram_model.py diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py new file mode 100644 index 000000000000..5ce79c204c00 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py @@ -0,0 +1,114 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (2, 2, 3, 2) +grid_shape = (2, 3, 2, 2) + + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape), mb.TensorSpec(shape=grid_shape)], opset_version=target) +def prog(x, grid): + sampling = mb.const(name="sampling_mode", val="bilinear") + padding_mode = mb.const(name="pmode", val="reflection") + pad = mb.const(name="pval", val=np.float32(0)) + coord_mode = mb.const(name="coord_mode", val="normalized_minus_one_to_one") + align_corners = mb.const(name="align_corners", val=False) + z = mb.resample( + x=x, + coordinates=grid, + sampling_mode=sampling, + padding_mode=padding_mode, + padding_value=pad, + coordinates_mode=coord_mode, + align_corners=align_corners, + ) + + return z + + +# print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +# spec = m.get_spec() +# print(spec) + +m.save("GridSample.mlpackage") +# construct MLModel with compute_units=ComputeUnit.CPU and run predict +m_cpu = ct.models.MLModel("GridSample.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY) +m_all = ct.models.MLModel("GridSample.mlpackage", compute_units=ct.ComputeUnit.ALL) + +# GridSampleTest.test_grid_sample_20_4D_bilinear_reflection_no_align_corners +# ORT produces different output for this test. 
ORT output is generated by pytorch +x = ( + np.array( + [ + -0.173652, + -1.513725, + -0.704586, + -1.952375, + -0.699404, + -0.806298, + 1.640852, + -0.138969, + -0.695411, + -1.352111, + 0.568797, + -0.564294, + -0.056468, + 0.641604, + -0.438370, + 0.450167, + -1.091401, + 1.669729, + -0.908544, + 0.244467, + 0.172109, + 1.156741, + -0.617128, + 1.155460, + ] + ) + .astype(np.float32) + .reshape(x_shape) +) + +grid = ( + np.array( + [ + 0.252250, + -0.151452, + 0.824706, + -0.588292, + -0.591147, + -0.155082, + -0.732938, + 0.457493, + -0.439559, + 0.492330, + 0.696447, + 0.700722, + -0.220298, + 0.654884, + -0.635434, + -1.195619, + -0.114204, + -0.870080, + -0.929674, + 0.305035, + 1.025429, + -0.472240, + -0.067881, + -0.869393, + ] + ) + .astype(np.float32) + .reshape(grid_shape) +) + + +print(m_cpu.predict({"x": x, "grid": grid})) +print(m_all.predict({"x": x, "grid": grid})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py new file mode 100644 index 000000000000..f83dc6ddfe02 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py @@ -0,0 +1,51 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (1, 1, 3, 6) + +use_scale = False # set this to test upsample vs resize + + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target) +def prog(x): + global use_scale # noqa + + if use_scale: + align = mb.const(val=False) + scale_h = mb.const(val=float(1 / 3)) + scale_w = mb.const(val=float(1 / 3)) + z = mb.upsample_bilinear(x=x, scale_factor_height=scale_h, scale_factor_width=scale_w, align_corners=align) + else: + size_h = mb.const(val=1) + size_w = mb.const(val=2) + sampling_mode = mb.const(val="UNALIGN_CORNERS") + z = mb.resize_bilinear(x=x, target_size_height=size_h, target_size_width=size_w, sampling_mode=sampling_mode) + + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +x = np.array( + [ + [ + [ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + ] + ] + ], + dtype=np.float32, +) + +# spec = m.get_spec() +# print(spec) + +print(m.predict({"x": x})) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7ac68e3a9a69..7ed776f1358a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -179,12 +179,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MaxRoiPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, @@ -202,11 +205,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); @@ -407,12 +412,15 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, @@ -432,11 +440,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); @@ -724,12 +734,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 
17, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, @@ -751,6 +764,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceMin); @@ -758,6 +772,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int8_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, uint8_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, double, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int32_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 17, float, @@ -881,12 +896,15 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, int8_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, uint8_t, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceL1); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceL1); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceLogSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceLogSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); @@ -902,6 +920,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMean); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMin); @@ -909,6 +928,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); @@ -1343,18 +1363,24 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2165,18 +2205,24 @@ Status 
RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo(group_count) * packed_W_size_; @@ -472,6 +473,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { MlasConvDepthwise( worker_indirection_buffer, reordered_W, + Bdata, worker_output, static_cast(M), static_cast(output_count), diff --git a/onnxruntime/core/providers/cpu/math/clip.cc b/onnxruntime/core/providers/cpu/math/clip.cc index ddb64a5a0e46..200469bc4783 100644 --- a/onnxruntime/core/providers/cpu/math/clip.cc +++ b/onnxruntime/core/providers/cpu/math/clip.cc @@ -23,7 +23,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( float); ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( kCpuExecutionProvider, kOnnxDomain, Clip, 12, Input, 0, - float, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t); + float, MLFloat16, double, int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t); } // namespace op_kernel_type_control using EnabledClip11Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 1d524a90302e..5ea6000da1cb 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -705,7 +705,7 @@ Status Min_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - min = min.array().min(EigenMap(data_n).array()); + min = min.array().template min(EigenMap(data_n).array()); } return Status::OK(); @@ -721,15 +721,16 @@ struct Min_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().min(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template min(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template min(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template min( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); @@ -827,7 +828,7 @@ Status Max_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); 
ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - max = max.array().max(EigenMap(data_n).array()); + max = max.array().template max(EigenMap(data_n).array()); } return Status::OK(); @@ -843,15 +844,16 @@ struct Max_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().max(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template max(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template max(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template max( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 244da35427f4..5aac1d9387f5 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -123,30 +123,42 @@ namespace onnxruntime { x); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL1, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL1, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL1, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL1, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceL1, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceL1, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL2, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL2, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL2, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL2, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceL2, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceL2, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceL2, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceL2, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceLogSum, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceLogSum, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceLogSum, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceLogSum, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceLogSum, 
13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceLogSum, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceLogSum, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceLogSum, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 1, 10); @@ -202,6 +214,10 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMean, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMean, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceMean, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceMean, 18); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMean, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMean, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMean, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceMean, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceMin, 1, 10); @@ -236,12 +252,16 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMin, 20); REGISTER_UNARY_ELEMENTWISE_KERNEL_BOOL_ONLY(ReduceMin, 20); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceProd, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceProd, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ReduceProd, 13, 17); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 13, 17); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 18); +REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceProd, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceProd, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10); diff --git a/onnxruntime/core/providers/cuda/cuda_call.cc b/onnxruntime/core/providers/cuda/cuda_call.cc index c73b23f3762e..511a6e2dce19 100644 --- a/onnxruntime/core/providers/cuda/cuda_call.cc +++ b/onnxruntime/core/providers/cuda/cuda_call.cc @@ -34,7 +34,6 @@ const char* CudaErrString(cudaError_t x) { template <> const char* CudaErrString(cublasStatus_t e) { cudaDeviceSynchronize(); - switch (e) { CASE_ENUM_TO_STR(CUBLAS_STATUS_SUCCESS); CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED); @@ -87,9 +86,15 @@ const char* CudaErrString(ncclResult_t e) { } #endif -template +template +int GetErrorCode(ERRTYPE err) { + return static_cast(err); +} + +template std::conditional_t CudaCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line) { if (retCode != successCode) { try { #ifdef _WIN32 @@ -108,7 +113,7 @@ std::conditional_t CudaCall( cudaGetLastError(); // clear last CUDA error static char str[1024]; snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, (int)retCode, CudaErrString(retCode), currentCudaDevice, + libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice, hostname, file, line, exprString, msg); if constexpr 
(THRW) { @@ -118,7 +123,8 @@ std::conditional_t CudaCall( LOGS_DEFAULT(ERROR) << str; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); } - } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, + // so we'd never get to see the error if constexpr (THRW) { ORT_THROW(e.what()); } else { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 5771380433b3..f74754c3cd06 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -180,6 +180,7 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_)); CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream)); + LOGS_DEFAULT(INFO) << "cuDNN version: " << cudnnGetVersion(); #endif cuda_graph_.SetStream(stream); } @@ -2469,6 +2470,19 @@ static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const return false; } +static bool NhwcConvNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, + [[maybe_unused]] const GraphViewer& graph_viewer, + [[maybe_unused]] const bool prefer_nhwc) { + // NHWC implementation doesn't handle W in NHWC layout if it's not an initializer + if (!graph_viewer.IsConstantInitializer(node.InputDefs()[1]->Name(), true)) { + LOGS(logger, WARNING) << "Dropping the NhwcConv node: " << node.Name() + << " to CPU because the Cuda EP requires W as initializer for NHWC operation."; + return true; + } + + return false; +} + static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { const auto& node_attributes = node.GetAttributes(); // Check attributes @@ -2539,6 +2553,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside + } else if ("NhwcConv" == node.OpType()) { + not_supported = NhwcConvNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred()); + force_inside = !not_supported; } if (!force_inside && not_supported) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 9c8a8712ca51..0871f7e4d0a7 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -82,6 +82,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } bool IsNHWCPreferred() const { return info_.prefer_nhwc; } + bool IsFuseConvBias() const { return info_.fuse_conv_bias; } bool UseTF32() const { return info_.use_tf32; } #ifndef DISABLE_CONTRIB_OPS diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 31cf991a34fc..14195703d596 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s constexpr const char* 
kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; constexpr const char* kUseTF32 = "use_tf32"; +constexpr const char* kFuseConvBias = "fuse_conv_bias"; constexpr const char* kSdpaKernel = "sdpa_kernel"; } // namespace provider_option_names @@ -119,6 +120,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) .AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel) + .AddAssignmentToReference(cuda::provider_option_names::kFuseConvBias, info.fuse_conv_bias) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -173,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, + {cuda::provider_option_names::kFuseConvBias, MakeStringWithClassicLocale(info.fuse_conv_bias)}, }; return options; @@ -195,6 +198,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kFuseConvBias, MakeStringWithClassicLocale(info.fuse_conv_bias)}, {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 0efad80f743d..bfd50ca8d40a 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -78,6 +78,7 @@ struct CUDAExecutionProviderInfo { // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. 
bool use_tf32{true}; + bool fuse_conv_bias{false}; int sdpa_kernel{0}; @@ -107,7 +108,8 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ (static_cast(info.prefer_nhwc) << 28) ^ (static_cast(info.use_ep_level_unified_stream) << 29) ^ - (static_cast(info.use_tf32) << 30); + (static_cast(info.use_tf32) << 30) ^ + (static_cast(info.fuse_conv_bias) << 31); onnxruntime::HashCombine(data, value); onnxruntime::HashCombine(info.gpu_mem_limit, value); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index b1d54e56ded4..83a5d02d16c6 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -219,6 +219,7 @@ struct CUDA_Provider : Provider { info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; info.prefer_nhwc = params->prefer_nhwc; + info.fuse_conv_bias = params->fuse_conv_bias; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; @@ -262,6 +263,7 @@ struct CUDA_Provider : Provider { cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; cuda_options.use_tf32 = internal_options.use_tf32; cuda_options.sdpa_kernel = internal_options.sdpa_kernel; + cuda_options.fuse_conv_bias = internal_options.fuse_conv_bias; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 14b75d2383b5..e9b159516dad 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -215,6 +215,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::prefer_nhwc_t: return reinterpret_cast(ep_info_.prefer_nhwc); break; + case CudaResource::fuse_conv_bias_t: + return reinterpret_cast(ep_info_.fuse_conv_bias); + break; case CudaResource::use_tf32_t: return reinterpret_cast(ep_info_.use_tf32); break; diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 31fc63a86d64..72482266a2ee 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -3,6 +3,8 @@ // Licensed under the MIT License. 
#include +#include +#include #include "core/providers/cuda/cudnn_common.h" #include "core/common/inlined_containers.h" @@ -60,7 +62,25 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat dims[1] = gsl::narrow_cast(input_dims[rank - 1]); strides[1] = gsl::narrow_cast(pitches[rank - 1]); } - CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); + CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), + dims.data(), strides.data())); + return Status::OK(); +} + +Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dataType, + gsl::span input_strides) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + + int rank = gsl::narrow_cast(input_dims.size()); + InlinedVector dims(rank); + InlinedVector strides(rank); + + for (int i = 0; i < rank; i++) { + dims[i] = gsl::narrow_cast(input_dims[i]); + strides[i] = gsl::narrow_cast(input_strides[i]); + } + CUDNN_RETURN_IF_ERROR( + cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast(rank), dims.data(), strides.data())); return Status::OK(); } @@ -100,7 +120,8 @@ Status CudnnDataTensor::Set(cudnnDataType_t dataType, const int32_t* seq_lengths) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, + // so that it will auto fill 0 for the shorter sequences cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; float padding_fill = 0.0f; CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout, @@ -238,6 +259,91 @@ const Float8E5M2 Consts::One = Float8E5M2(1.0f, true); #endif +std::vector generateStrides(const std::vector& shape, bool channels_last) { + // For INT8x4 and INT8x32 we still compute standard strides here to input + // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. 
+ std::vector strides(shape.size()); + int64_t nbDims = strides.size(); + if (nbDims <= 1) { + strides[0] = 1; + return strides; + } + if (channels_last) { + // Here we assume that the format is CUDNN_TENSOR_NHWC + strides[1] = 1; + strides[nbDims - 1] = strides[1] * shape[1]; + for (int64_t d = nbDims - 2; d >= 2; d--) { + strides[d] = strides[d + 1] * shape[d + 1]; + } + strides[0] = strides[2] * shape[2]; + } else { + strides[nbDims - 1] = 1; + for (int64_t d = nbDims - 2; d >= 0; d--) { + strides[d] = strides[d + 1] * shape[d + 1]; + } + } + return strides; +} + +#if !defined(__CUDACC__) +CudnnFeTensor::CudnnFeTensor(const onnxruntime::TensorShapeVector& shape, + const std::string& name, + std::optional dtype, + const bool nhwc) { + std::vector shape_vec; + if (shape.size() == 1) { + shape_vec = {1, shape[0], 1, 1}; + } else if (shape.size() >= 4) { + for (size_t i = 0; i < shape.size(); i++) { + shape_vec.push_back(shape[i]); + } + } else { + ORT_THROW("Invalid tensor shape size, tensor name: ", name, ", shape size: ", shape.size()); + } + auto strides = generateStrides(shape_vec, nhwc); + + if (dtype.has_value()) { + tensor_ = cudnn_frontend::graph::Tensor_attributes() + .set_name(name) + .set_dim(shape_vec) + .set_stride(strides) + .set_data_type(dtype.value()); + } else { + tensor_ = cudnn_frontend::graph::Tensor_attributes().set_name(name).set_dim(shape_vec).set_stride(strides); + } +} + +template +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::NOT_SET; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::FLOAT; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::HALF; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::DOUBLE; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::INT8; +} + +template <> +cudnn_frontend::DataType_t CudnnFeTensor::GetDataType() { + return cudnn_frontend::DataType_t::UINT8; +} +#endif + } // namespace cuda } // namespace onnxruntime #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 2cbeb1369627..b267ef6bed64 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -5,12 +5,22 @@ #pragma once #include +#include +#include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" + #ifndef USE_CUDA_MINIMAL +#if !defined(__CUDACC__) +#include +#endif + namespace onnxruntime { namespace cuda { +#define CUDNN_FE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_FE_CALL(expr)) + class CudnnTensor final { public: CudnnTensor(); @@ -19,6 +29,7 @@ class CudnnTensor final { Status Set(gsl::span input_dims, cudnnDataType_t dataType, bool is_nhwc = false); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); + Status Set(gsl::span input_dims, cudnnDataType_t dataType, gsl::span input_strides); // Set 4D tensor format (for NHWC) Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); @@ -139,7 +150,8 @@ struct Consts { inline double ClampCudnnBatchNormEpsilon(double epsilon) { if (epsilon < CUDNN_BN_MIN_EPSILON) { if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) - LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. 
Setting it to CUDNN_BN_MIN_EPSILON"; + LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. " + << "Setting it to CUDNN_BN_MIN_EPSILON"; return CUDNN_BN_MIN_EPSILON; } return epsilon; @@ -258,6 +270,23 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc, return cudnnSetPoolingNdDescriptor(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); } +std::vector generateStrides(const std::vector& shape, bool channels_last); + +#if !defined(__CUDACC__) +class CudnnFeTensor final { + public: + CudnnFeTensor(const onnxruntime::TensorShapeVector& shape, const std::string& name, + std::optional dtype, const bool nhwc); + + template + static cudnn_frontend::DataType_t GetDataType(); + cudnn_frontend::graph::Tensor_attributes Get() { return tensor_; } + + private: + cudnn_frontend::graph::Tensor_attributes tensor_; +}; +#endif + } // namespace cuda } // namespace onnxruntime #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_fe_call.cc b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc new file mode 100644 index 000000000000..640025c24818 --- /dev/null +++ b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#include "core/providers/shared_library/provider_api.h" +#include +#if !defined(__CUDACC__) +#include +#endif +#ifdef _WIN32 +#else // POSIX +#include +#include +#endif + +namespace onnxruntime { + +using namespace common; + +template +const char* CudaErrString(ERRTYPE) { + ORT_NOT_IMPLEMENTED(); +} + +#if !defined(__CUDACC__) +#define CASE_ENUM_TO_STR_CUDNN_FE(x) \ + case cudnn_frontend::error_code_t::x: \ + return #x + +template <> +const char* CudaErrString(cudnn_frontend::error_t x) { + cudaDeviceSynchronize(); + LOGS_DEFAULT(ERROR) << x.get_message(); + switch (x.get_code()) { + CASE_ENUM_TO_STR_CUDNN_FE(OK); + CASE_ENUM_TO_STR_CUDNN_FE(ATTRIBUTE_NOT_SET); + CASE_ENUM_TO_STR_CUDNN_FE(SHAPE_DEDUCTION_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(INVALID_TENSOR_NAME); + CASE_ENUM_TO_STR_CUDNN_FE(INVALID_VARIANT_PACK); + CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_NOT_SUPPORTED); + CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_EXECUTION_PLAN_CREATION_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(GRAPH_EXECUTION_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(HEURISTIC_QUERY_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(UNSUPPORTED_GRAPH_FORMAT); + CASE_ENUM_TO_STR_CUDNN_FE(CUDA_API_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(CUDNN_BACKEND_API_FAILED); + CASE_ENUM_TO_STR_CUDNN_FE(INVALID_CUDA_DEVICE); + CASE_ENUM_TO_STR_CUDNN_FE(HANDLE_ERROR); + default: + return "Unknown CUDNN_FRONTEND error status"; + } +} + +template +int GetErrorCode(ERRTYPE err) { + return static_cast(err); +} + +template <> +int GetErrorCode(cudnn_frontend::error_t err) { + return static_cast(err.get_code()); +} + +template +std::conditional_t CudaCall( + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line) { + if (retCode != successCode) { + try { +#ifdef _WIN32 + std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + if (hostname_str.empty()) { + hostname_str = "?"; + } + const char* hostname = hostname_str.c_str(); +#else + char hostname[HOST_NAME_MAX]; + if (gethostname(hostname, HOST_NAME_MAX) != 0) + strcpy(hostname, "?"); +#endif + int currentCudaDevice; + cudaGetDevice(¤tCudaDevice); + cudaGetLastError(); // clear last CUDA error + static char 
str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", + libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice, + hostname, + file, line, exprString, msg); + if constexpr (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + LOGS_DEFAULT(ERROR) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + } + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, + // so we'd never get to see the error + if constexpr (THRW) { + ORT_THROW(e.what()); + } else { + LOGS_DEFAULT(ERROR) << e.what(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + } + } + } + if constexpr (!THRW) { + return Status::OK(); + } +} + +template Status CudaCall( + cudnn_frontend::error_t retCode, const char* exprString, const char* libName, + cudnn_frontend::error_code_t successCode, const char* msg, const char* file, const int line); +template void CudaCall( + cudnn_frontend::error_t retCode, const char* exprString, const char* libName, + cudnn_frontend::error_code_t successCode, const char* msg, const char* file, const int line); + +#endif +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 764feadcf4cb..95ba698b707a 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -3,14 +3,20 @@ // Licensed under the MIT License. #include +#include +#include +#include "core/common/status.h" #include "core/providers/cuda/nn/conv.h" #include "core/common/span_utils.h" #include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/providers/cuda/tensor/slice.h" #include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#if CUDNN_MAJOR < 9 +// if compiled with cuDNN 8 we want to use the legacy cuDNN API +#include "conv_8.h" +#endif namespace onnxruntime { namespace cuda { @@ -43,58 +49,7 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) #endif -template -const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { - CUDNN_CONVOLUTION_FWD_ALGO_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_FFT, - CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, -}; - -cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, cudnnConvolutionFwdAlgo_t algo, size_t* sz) { - return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); -} - -size_t GetMaxWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, - const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { - // TODO: get maximum available size from memory arena - size_t free, total; - CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); - // Assuming 10% of fragmentation - free = static_cast(static_cast(free) * 0.9); - size_t max_ws_size = 0; - for (int i = 0; i < n_algo; i++) { - cudnnStatus_t err; - size_t sz; - err = GetWorkspaceSize(handle, s, algo[i], &sz); - if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; - max_ws_size = sz; - } - return max_ws_size; -} - -Status SliceOutUnwantedOutputSection(cudaStream_t stream, - const 
void* input_data, gsl::span input_dims, - void* output_data, - const gsl::span& output_dims, - const gsl::span& starts, - const gsl::span& ends, - const gsl::span& axes, - size_t element_size) { - SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); - - ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); - - // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape - ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); - - return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); -} - +// When NHWC == true, the first input (X) is also in NHWC format; the other inputs are in NCHW template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { @@ -104,14 +59,20 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); + auto shape_size = orig_shape.GetDims().size(); - InlinedVector perm{0, 2, 3, 1}; - gsl::span permutation(perm.data(), 4); - TensorShapeVector new_dims{orig_shape[0], - orig_shape[2], - orig_shape[3], - orig_shape[1]}; - W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + InlinedVector perm; + perm.push_back(0); + for (size_t i = 2; i < shape_size; i++) perm.push_back(i); + perm.push_back(1); + gsl::span permutation(perm.data(), shape_size); + + TensorShapeVector nhwc_dims; + for (size_t i = 0; i < shape_size; i++) { + nhwc_dims.push_back(orig_shape[perm[i]]); + } + + W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc)); auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), @@ -122,6 +83,8 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; + } else { + W_already_nhwc = true; } } else { ORT_UNUSED_PARAMETER(tensor); @@ -132,45 +95,205 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr return Status::OK(); } -template -Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { +#if CUDNN_MAJOR >= 9 +#if !defined(__CUDACC__) + +template +Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const Tensor* Z, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool bias_expected, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const { + s_.bias_fused = fuse_bias; + s_.act_fused = fuse_act; + s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change + s_.cudnn_fe_graph = std::make_unique(); + cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType(); + s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type); + if (data_type == cudnn_frontend::DataType_t::HALF) { + s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + } else { + s_.cudnn_fe_graph->set_compute_data_type(data_type); + } + + s_.cudnn_fe_X =
s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get()); + s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get()); + + auto conv_options = cudnn_frontend::graph::Conv_fprop_attributes() + .set_pre_padding(std::vector(pads.begin(), + pads.begin() + pads.size() / 2)) + .set_post_padding(std::vector(pads.begin() + pads.size() / 2, pads.end())) + .set_stride(strides) + .set_dilation(dilations); + s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_fprop(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options); + auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get(); + + if (!bias_expected && B == nullptr) { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + } else { + int64_t bias_size; + if (B != nullptr) { + bias_size = B->Shape()[0]; + } else { + bias_size = w_dims[0]; + } + + std::optional cudnn_fe_z_tensor; + if (Z) { + const auto& z_shape = Z->Shape().AsShapeVector(); + cudnn_fe_z_tensor = CudnnFeTensor(z_shape, "z", data_type, Layout == LAYOUT_NHWC).Get(); + } else if (fuse_bias && Layout == LAYOUT_NCHW) { + // Z is required for NCHW precompiled kernels in cuDNN + s_.z_data = s_.y_data; + cudnn_fe_z_tensor = cudnn_fe_y_tensor; + } + + if (fuse_bias) { + std::shared_ptr add_output; + if (cudnn_fe_z_tensor.has_value()) { + s_.cudnn_fe_Z = s_.cudnn_fe_graph->tensor(cudnn_fe_z_tensor.value()); + auto add_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + add_output = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_Z, add_options); + } else { + add_output = s_.cudnn_fe_conv_Y; + } + + onnxruntime::TensorShapeVector b_dims; + for (size_t i = 0; i < x_dims.size(); i++) { + b_dims.push_back(i == 1 ? bias_size : 1); + } + auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get(); + auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(add_output, s_.cudnn_fe_B, bias_options); + } else { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + + TensorShapeVector b_dims(y_dims.size(), 1); + TensorShapeVector b_strides(y_dims.size(), 1); + b_dims[1] = bias_size; + b_strides[0] = bias_size; + if (Z) { + ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().AsShapeVector(), + CudnnTensor::GetDataType(), + cudnn_fe_z_tensor->get_stride())); + } + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), b_strides)); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType(), cudnn_fe_y_tensor.get_stride())); + + /* Creating an own CUDNN Frontend graph for the bias addition. + s_.cudnn_fe_bias_graph = std::make_unique(); + s_.cudnn_fe_bias_graph->set_io_data_type(data_type) + .set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ? 
+ cudnn_frontend::DataType_t::FLOAT : data_type) + .set_intermediate_data_type(data_type); + s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor(y_dims, "x", data_type).Get()); + + s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor); + s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options); + s_.cudnn_fe_bias_Y->set_output(true); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate()); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode})); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/ + } + } + if (fuse_act && s_.cudnn_fe_act_attr.has_value()) { + auto& activation_attr = s_.cudnn_fe_act_attr.value(); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr); + } + + s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim()); + s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride()); + s_.cudnn_fe_Y->set_output(true); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate()); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); + } catch (const std::exception& ex) { + std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE}); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); + } catch (const std::exception& ex) { + if (!fuse_bias && !fuse_act && use_tf32) { + std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + // Try fallback. + return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, Z, y_dims, handle, heur_mode, + pads, strides, dilations, bias_expected, false, false, w_in_nhwc, true); + } + + s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size(); + return Status::OK(); +} + +#endif + +template +Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { + constexpr bool channels_last = Layout; + // set X const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); + // X (incl. x_dims) is in NHWC format iff NHWC == true const auto x_dims = x_shape.AsShapeVector(); + s_.x_data = reinterpret_cast(X->Data()); s_.element_size = X->DataType()->Size(); // set W + bool w_in_nhwc; const Tensor* W; if (!W_) { W = context->Input(1); + w_in_nhwc = W_already_nhwc; + // Dims and memory layout are in NCHW format } else { W = W_.get(); + w_in_nhwc = true; + // W was prepacked; therefore, if NHWC == true, dims and memory layout are in NHWC } const TensorShape& w_shape = W->Shape(); - auto w_dims = w_shape.AsShapeVector(); + onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector(); s_.w_data = reinterpret_cast(W->Data()); - // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC.
- constexpr bool channels_last = NHWC; - if constexpr (channels_last) { - if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); - } - } - // set B + // Always in NCHW format + const Tensor* B = nullptr; if (context->InputCount() >= 3) { - const Tensor* B = context->Input(2); + B = context->Input(2); s_.b_data = reinterpret_cast(B->Data()); } else { s_.b_data = nullptr; } + // set Z + const Tensor* Z = nullptr; if (context->InputCount() >= 4) { - const Tensor* Z = context->Input(3); - ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType())); + Z = context->Input(3); s_.z_data = reinterpret_cast(Z->Data()); } else { s_.z_data = nullptr; @@ -183,13 +306,12 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_results.clear(); } - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, w_in_nhwc)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, channels_last)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, w_in_nhwc)); const size_t kernel_rank = kernel_shape.size(); @@ -211,59 +333,46 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const int64_t N = X->Shape()[0]; const int64_t M = W->Shape()[0]; - if (channels_last) { + + if constexpr (channels_last) { y_dims.push_back(N); } else { y_dims.insert(y_dims.begin(), {N, M}); } - bool post_slicing_required = false; - TensorShapeVector slice_starts; - slice_starts.reserve(kernel_rank); - - TensorShapeVector slice_ends; - slice_ends.reserve(kernel_rank); - - TensorShapeVector slice_axes; - slice_axes.reserve(kernel_rank); - constexpr size_t spatial_dim_start = channels_last ? 1 : 2; const size_t spatial_dim_end = spatial_dim_start + kernel_rank; TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); - TensorShapeVector y_dims_with_adjusted_pads(y_dims); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, - strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, - post_slicing_required, slice_starts, slice_ends, slice_axes, - channels_last)); - if (channels_last) { + ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(spatial_shape, kernel_shape, + strides, dilations, pads, y_dims)); + if constexpr (channels_last) { y_dims.push_back(M); - y_dims_with_adjusted_pads.push_back(M); } - ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); s_.y_dims = gsl::make_span(y_dims); - s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; - s_.post_slicing_required = post_slicing_required; - s_.slice_starts = slice_starts; - s_.slice_ends = slice_ends; - s_.slice_axes = slice_axes; - s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (post_slicing_required) { - // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); - s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); - } else { - // No post slicing needed. 
Fill the output tensor's buffer directly. - s_.y_data = reinterpret_cast(s_.Y->MutableData()); - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); const CUDAExecutionProvider* cuda_ep = static_cast(this->Info().GetExecutionProvider()); TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; - TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; + TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; + TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()}; + + if constexpr (channels_last) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1)); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1)); + x_dims_cudnn.erase(x_dims_cudnn.end() - 1); + y_dims_cudnn.erase(y_dims_cudnn.end() - 1); + + if (w_in_nhwc) { + w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1)); + w_dims_cudnn.erase(w_dims_cudnn.end() - 1); + } + } + if (kernel_rank < 2) { // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] // especially for EXHAUSTIVE algo search which may result in a better algo selection. @@ -276,7 +385,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (cuda_ep->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); - w_dims.insert(w_dims.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); @@ -285,7 +394,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } else { x_dims_cudnn.push_back(1); y_dims_cudnn.push_back(1); - w_dims.push_back(1); + w_dims_cudnn.push_back(1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.end(), 0); kernel_shape.push_back(1); @@ -294,188 +403,105 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } } - if (w_dims_changed) { - if (!channels_last) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, - CudnnTensor::GetDataType(), - static_cast(w_dims[0]), - static_cast(w_dims[3]), - static_cast(w_dims[1]), - static_cast(w_dims[2]))); - } - } - // We must delay returning early until here so that the weight dims have been cached properly if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - if (channels_last) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, - CudnnTensor::GetDataType(), - static_cast(x_dims_cudnn[0]), - static_cast(x_dims_cudnn[3]), - static_cast(x_dims_cudnn[1]), - static_cast(x_dims_cudnn[2]))); - - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, - CudnnTensor::GetDataType(), - static_cast(y_dims_cudnn[0]), - static_cast(y_dims_cudnn[3]), - static_cast(y_dims_cudnn[1]), - static_cast(y_dims_cudnn[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + auto handle = GetCudnnHandle(context); + + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); +#if !defined(__CUDACC__) + cudnn_frontend::HeurMode_t heur_mode; + switch (cudnn_conv_algo) { + case 0: + heur_mode = cudnn_frontend::HeurMode_t::B; + break; + case 1: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + case 2: + heur_mode = cudnn_frontend::HeurMode_t::FALLBACK; + LOGS_DEFAULT(WARNING) << "OP " << CudaKernel::Node().OpType() << 
"(" << CudaKernel::Node().Name() + << ") running in Fallback mode. May be extremely slow."; + break; + default: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; } - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, - gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), - UseTF32())); - - if (context->InputCount() >= 3) { - const Tensor* B = context->Input(2); - const auto& b_shape = B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - TensorShapeVector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = b_shape[0]; - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); - // s_.b_data = reinterpret_cast(B->Data()); - } else if (bias_expected) { - TensorShapeVector b_dims(2 + kernel_shape.size(), 1); - b_dims[1] = w_dims[0]; - auto malloc_size = b_dims[1] * sizeof(CudaT); - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); - if (s_.b_zero) { - CUDA_CALL_THROW(cudaFree(s_.b_zero)); - s_.b_zero = nullptr; - } - CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size)); - CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); - } - - if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { - // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); - } else if constexpr (std::is_same::value) { - if (!UseTF32()) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); - } - } - - cudnnConvolutionFwdAlgoPerf_t perf; - int algo_count = 1; - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); - ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); - switch (cudnn_conv_algo) { - case 0: { - static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. 
- IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( - GetCudnnHandle(context), - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.y_tensor, - s_.y_data, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf, - algo_search_workspace.get(), - max_ws_size)); - break; - } - case 1: - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( - GetCudnnHandle(context), - s_.x_tensor, - s_.w_desc, - s_.conv_desc, - s_.y_tensor, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf)); - break; - - default: - perf.algo = kDefaultConvAlgo; - CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); - - if constexpr (std::is_same::value) { - perf.mathType = CUDNN_TENSOR_OP_MATH; - } else if (std::is_same::value && !UseTF32()) { - perf.mathType = CUDNN_FMA_MATH; - } else { - perf.mathType = CUDNN_DEFAULT_MATH; - } - } - s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); - } - const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); - s_.algo = perf.algo; - s_.workspace_bytes = perf.memory; + const auto use_tf32 = cuda_ep->UseTF32(); + // fuse if this op is part of a FusedConv or if the EP is set to fuse ops + const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_act = is_fused_node_; + + ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, Z, y_dims_cudnn, handle, heur_mode, + std::vector(pads.begin(), + pads.end()), + std::vector(strides.begin(), + strides.end()), + std::vector(dilations.begin(), + dilations.end()), + bias_expected, fuse_bias, fuse_act, w_in_nhwc, use_tf32)); +#endif } else { // set Y s_.Y = context->Output(0, s_.y_dims); if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - if (s_.post_slicing_required) { - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); - s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); - } else { - s_.y_data = reinterpret_cast(s_.Y->MutableData()); - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); } return Status::OK(); } -template -Status Conv::ComputeInternal(OpKernelContext* context) const { +template +Status Conv::ComputeInternal(OpKernelContext* context) const { std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + const auto alpha = onnxruntime::cuda::Consts::One; auto cudnn_handle = GetCudnnHandle(context); - CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle, - &alpha, - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.algo, - workspace.get(), - s_.workspace_bytes, - &beta, - s_.y_tensor, - s_.y_data)); - if (nullptr != s_.b_data) { - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, +#if !defined(__CUDACC__) + s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast(s_.x_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast(s_.w_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data); + if (s_.bias_fused && s_.b_data != nullptr) { + 
s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + } + if (s_.bias_fused && s_.z_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast(s_.z_data)); + if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) { + // memset Z if it's required for a successful fusion + CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); + } + } + auto ws = GetWorkSpace(context->GetComputeStream()); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, + s_.variant_pack, + ws.get())); + + if (!s_.bias_fused && s_.z_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data, &alpha, s_.y_tensor, s_.y_data)); } - // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions - This may have lead to extra results that are unnecessary and hence we slice that off here - if (s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), - s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, - s_.slice_ends, s_.slice_axes, s_.element_size)); + if (!s_.bias_fused && s_.b_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + + /* For the standalone bias addition graph. + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle, + s_.variant_pack_bias, + GetWorkSpace(context->GetComputeStream()).get()));*/ } +#endif + return Status::OK(); } @@ -536,6 +562,7 @@ Status CudnnConvolutionDescriptor::Set( return Status::OK(); } +#endif #ifndef DISABLE_CONTRIB_OPS // template instantiation for NhwcConv diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 3aec654224e3..484d66081018 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -6,6 +6,12 @@ #include #include +#include +#include + +#if !defined(__CUDACC__) +#include +#endif #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_kernel.h" @@ -150,6 +156,24 @@ struct CudnnConvState { CudnnTensor z_tensor; const void* z_data = nullptr; CudnnConvolutionDescriptor conv_desc; + bool bias_fused = true; + bool act_fused = true; + +#if !defined(__CUDACC__) + std::unique_ptr cudnn_fe_graph; + std::unique_ptr cudnn_fe_bias_graph; + std::shared_ptr cudnn_fe_X; + std::shared_ptr cudnn_fe_W; + std::shared_ptr cudnn_fe_conv_Y; + std::shared_ptr cudnn_fe_Z; + std::shared_ptr cudnn_fe_B; + std::shared_ptr cudnn_fe_Y; + + std::optional cudnn_fe_act_attr = std::nullopt; + + std::unordered_map, void*> variant_pack; + std::unordered_map, void*> variant_pack_bias; +#endif struct PerfResultParams { decltype(AlgoPerfType().algo) algo; @@ -183,7 +207,7 @@ enum : size_t { // ONNX Conv operator uses NCHW format for input, weights and output. // NhwcConv contrib ops uses NHWC format: last dimension of input, weights and output are channels. 
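[Editor's sketch, not part of the patch] At execution time the plan prepared in UpdateState is replayed by binding device pointers to the graph's tensor handles and calling execute(), with unfused bias/Z inputs added afterwards via cudnnAddTensor as in the hunk above. A condensed sketch, assuming the cuDNN frontend v1.x API; the helper name and parameters are illustrative:

#if !defined(__CUDACC__)
#include <cudnn_frontend.h>
#include <memory>
#include <unordered_map>

using TensorPtr = std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>;

// Bind device buffers to the graph's tensor handles and run the prebuilt plan.
inline cudnn_frontend::error_t RunFeConv(cudnn_frontend::graph::Graph& graph,
                                         cudnnHandle_t handle,
                                         TensorPtr x, TensorPtr w, TensorPtr y,
                                         void* x_data, void* w_data, void* y_data,
                                         void* workspace) {
  std::unordered_map<TensorPtr, void*> variant_pack{{x, x_data}, {w, w_data}, {y, y_data}};
  return graph.execute(handle, variant_pack, workspace);
}
#endif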
-template +template class Conv : public CudaKernel { public: using CudaT = typename ToCudaType::MappedType; @@ -205,12 +229,32 @@ class Conv : public CudaKernel { } Status UpdateState(OpKernelContext* context, bool bias_expected = false) const; + +#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9 + Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const Tensor* Z, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool bias_expected, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const; +#endif + ConvAttributes conv_attrs_; mutable CudnnConvState s_; constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - static const cudnnConvolutionFwdAlgo_t kAllAlgos[]; std::unique_ptr W_; - bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_fused_node_ = false; // ensures the node is fused although the session option is not set + bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain }; Status SliceOutUnwantedOutputSection(cudaStream_t stream, diff --git a/onnxruntime/core/providers/cuda/nn/conv_8.h b/onnxruntime/core/providers/cuda/nn/conv_8.h new file mode 100644 index 000000000000..10239d09041f --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/conv_8.h @@ -0,0 +1,484 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
+ +#include "core/common/status.h" +#include "core/providers/cuda/nn/conv.h" +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#include "core/providers/cuda/tensor/slice.h" + +namespace onnxruntime { +namespace cuda { + +static const cudnnConvolutionFwdAlgo_t kAllAlgos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, +}; + +static cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, cudnnConvolutionFwdAlgo_t algo, size_t* sz) { + return cudnnGetConvolutionForwardWorkspaceSize(handle, s.x_tensor, s.w_desc, s.conv_desc, s.y_tensor, algo, sz); +} + +size_t GetMaxWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, + const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { + // TODO: get maximum available size from memory arena + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation + free = static_cast(static_cast(free) * 0.9); + size_t max_ws_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t err; + size_t sz; + err = GetWorkspaceSize(handle, s, algo[i], &sz); + if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > free) continue; + max_ws_size = sz; + } + return max_ws_size; +} + +Status SliceOutUnwantedOutputSection(cudaStream_t stream, + const void* input_data, gsl::span input_dims, + void* output_data, + const gsl::span& output_dims, + const gsl::span& starts, + const gsl::span& ends, + const gsl::span& axes, + size_t element_size) { + SliceOp::PrepareForComputeMetadata compute_metadata(input_dims); + + ORT_THROW_IF_ERROR(SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata)); + + // As a sanity check, ensure that the slice operator's output shape matches with the expected output shape + ORT_ENFORCE(SpanEq(gsl::make_span(compute_metadata.output_dims_), output_dims)); + + return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); +} + +template +Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + const auto x_dims = x_shape.AsShapeVector(); + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + bool w_in_nhwc; + const Tensor* W; + if (!W_) { + W = context->Input(1); + w_in_nhwc = W_already_nhwc; + // Dims and memory layout are in NCHW format + } else { + W = W_.get(); + w_in_nhwc = true; + // W got prepacked, therfore if NHWC == true, then dims and memory layout are in NHWC + } + const TensorShape& w_shape = W->Shape(); + auto w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
+ constexpr bool channels_last = NHWC; + if constexpr (channels_last) { + if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } + } + + // set B + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } + // set Z + if (context->InputCount() >= 4) { + const Tensor* Z = context->Input(3); + ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType())); + s_.z_data = reinterpret_cast(Z->Data()); + } else { + s_.z_data = nullptr; + } + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, w_in_nhwc)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, w_in_nhwc)); + + const size_t kernel_rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(kernel_rank * 2, 0); + } + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_rank, 1); + } + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_rank, 1); + } + + TensorShapeVector y_dims; + y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' + + const int64_t N = X->Shape()[0]; + const int64_t M = W->Shape()[0]; + if (channels_last) { + y_dims.push_back(N); + } else { + y_dims.insert(y_dims.begin(), {N, M}); + } + + bool post_slicing_required = false; + TensorShapeVector slice_starts; + slice_starts.reserve(kernel_rank); + + TensorShapeVector slice_ends; + slice_ends.reserve(kernel_rank); + + TensorShapeVector slice_axes; + slice_axes.reserve(kernel_rank); + + constexpr size_t spatial_dim_start = channels_last ? 1 : 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); + + TensorShapeVector y_dims_with_adjusted_pads(y_dims); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, + strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, + post_slicing_required, slice_starts, slice_ends, slice_axes, + channels_last)); + if (channels_last) { + y_dims.push_back(M); + y_dims_with_adjusted_pads.push_back(M); + } + + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); + s_.y_dims = gsl::make_span(y_dims); + s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; + s_.post_slicing_required = post_slicing_required; + s_.slice_starts = slice_starts; + s_.slice_ends = slice_ends; + s_.slice_axes = slice_axes; + + s_.Y = context->Output(0, TensorShape(s_.y_dims)); + if (post_slicing_required) { + // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + // No post slicing needed. 
Fill the output tensor's buffer directly. + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. + // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], especially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + if (w_dims_changed) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } else if (w_in_nhwc) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(w_dims[0]), + static_cast(w_dims[3]), + static_cast(w_dims[1]), + static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(w_dims[0]), + static_cast(w_dims[1]), + static_cast(w_dims[2]), + static_cast(w_dims[3]))); + } + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + + if (channels_last) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(x_dims_cudnn[0]), + static_cast(x_dims_cudnn[3]), + static_cast(x_dims_cudnn[1]), + static_cast(x_dims_cudnn[2]))); + + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(y_dims_cudnn[0]), + static_cast(y_dims_cudnn[3]), + static_cast(y_dims_cudnn[1]), + static_cast(y_dims_cudnn[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + } + + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); + + if (context->InputCount() >= 3) { + const Tensor* B = context->Input(2); + const auto& b_shape = B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = b_shape[0]; + 
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + // s_.b_data = reinterpret_cast(B->Data()); + } else if (bias_expected) { + TensorShapeVector b_dims(2 + kernel_shape.size(), 1); + b_dims[1] = w_dims[0]; + auto malloc_size = b_dims[1] * sizeof(CudaT); + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + if (s_.b_zero) { + CUDA_CALL_THROW(cudaFree(s_.b_zero)); + s_.b_zero = nullptr; + } + CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size)); + CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); + } + + if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionFwdAlgoPerf_t perf; + int algo_count = 1; + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); + switch (cudnn_conv_algo) { + case 0: { + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + GetCudnnHandle(context), + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.y_tensor, + s_.y_data, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf, + algo_search_workspace.get(), + max_ws_size)); + break; + } + case 1: + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7( + GetCudnnHandle(context), + s_.x_tensor, + s_.w_desc, + s_.conv_desc, + s_.y_tensor, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf)); + break; + + default: + perf.algo = kDefaultConvAlgo; + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); + + if constexpr (std::is_same::value) { + perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; + } else { + perf.mathType = CUDNN_DEFAULT_MATH; + } + } + s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType}); + } + const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (s_.post_slicing_required) { + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); + } else { + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + } + return Status::OK(); +} + +template +Status Conv::ComputeInternal(OpKernelContext* context) const { + 
std::lock_guard lock(s_.mutex); + ORT_RETURN_IF_ERROR(UpdateState(context)); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + auto cudnn_handle = GetCudnnHandle(context); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle, + &alpha, + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.algo, + workspace.get(), + s_.workspace_bytes, + &beta, + s_.y_tensor, + s_.y_data)); + if (nullptr != s_.b_data) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + } + // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions + // This may have lead to extra results that are unnecessary and hence we slice that off here + if (s_.post_slicing_required) { + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), + s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, + s_.slice_ends, s_.slice_axes, s_.element_size)); + } + return Status::OK(); +} + +CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() : desc_(nullptr) { +} + +CudnnConvolutionDescriptor::~CudnnConvolutionDescriptor() { + if (desc_ != nullptr) { + cudnnDestroyConvolutionDescriptor(desc_); + desc_ = nullptr; + } +} + +Status CudnnConvolutionDescriptor::Set( + size_t rank, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, + int groups, + cudnnConvolutionMode_t mode, + cudnnDataType_t data_type, + bool use_tf32) { + if (!desc_) + CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); + + InlinedVector pad_dims(rank); + InlinedVector stride_dims(rank); + InlinedVector dilation_dims(rank); + for (size_t i = 0; i < rank; i++) { + pad_dims[i] = gsl::narrow_cast(pads[i]); + stride_dims[i] = gsl::narrow_cast(strides[i]); + dilation_dims[i] = gsl::narrow_cast(dilations[i]); + } + + // This piece of code is copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h + // Setting math_type to CUDNN_DATA_FLOAT for half input + cudnnDataType_t math_type = data_type; + if (data_type == CUDNN_DATA_HALF) math_type = CUDNN_DATA_FLOAT; + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionNdDescriptor( + desc_, + gsl::narrow_cast(rank), + pad_dims.data(), + stride_dims.data(), + dilation_dims.data(), + mode, + math_type)); + + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(desc_, groups)); + + // Copied from /pytorch/aten/src/ATen/cudnn/Descriptors.h + // See Note [behavior of cudnnFind and cudnnGet] at /pytorch/aten/src/ATen/native/cudnn/Conv_v7.cpp + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); + if (data_type == CUDNN_DATA_HALF) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); + } + + return Status::OK(); +} +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h index 51a5631b930b..2b2b726e62c7 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h @@ -4,16 +4,16 @@ #pragma once #include "core/common/common.h" #include 
"core/providers/cuda/cuda_pch.h" - namespace onnxruntime { // ----------------------------------------------------------------------- // Error handling // ----------------------------------------------------------------------- -template +template std::conditional_t CudaCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line); #define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) #define CUBLAS_CALL(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h new file mode 100644 index 000000000000..a51d84a7efa5 --- /dev/null +++ b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_pch.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#if !defined(__CUDACC__) +#include +#endif +namespace onnxruntime { + +// ----------------------------------------------------------------------- +// Error handling +// ----------------------------------------------------------------------- + +#define CUDNN_FE_CALL(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ + cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) +#define CUDNN_FE_CALL_THROW(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ + cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 053c66ddcb34..240272923a3a 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -15,9 +15,6 @@ #include "core/providers/cuda/cuda_common.h" -using namespace onnxruntime; -using namespace onnxruntime::cuda; - // Generalize library calls to be use in template functions inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, @@ -84,7 +81,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, half* C, int ldc, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { return cublasGemmEx(handle, @@ -127,7 +124,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, half* C, int ldc, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { // The alpha and beta shall have same precision as compute type. 
@@ -162,8 +159,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, #if defined(USE_CUDA) inline cublasStatus_t cublasGemmHelper( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, - const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, + int n, int k, const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* A, int lda, + const onnxruntime::BFloat16* B, int ldb, const onnxruntime::BFloat16* beta, onnxruntime::BFloat16* C, int ldc, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -174,8 +171,9 @@ inline cublasStatus_t cublasGemmHelper( } #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, - const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*, - BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { + const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int, + const onnxruntime::BFloat16*, int, const onnxruntime::BFloat16*, + onnxruntime::BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -250,7 +248,7 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, int batch_count, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { return cublasGemmBatchedEx(handle, @@ -286,9 +284,9 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, #if defined(USE_CUDA) inline cublasStatus_t cublasGemmBatchedHelper( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], - int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, - BFloat16* Carray[], int ldc, int batch_count, + int m, int n, int k, const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* Aarray[], + int lda, const onnxruntime::BFloat16* Barray[], int ldb, const onnxruntime::BFloat16* beta, + onnxruntime::BFloat16* Carray[], int ldc, int batch_count, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -300,8 +298,9 @@ inline cublasStatus_t cublasGemmBatchedHelper( } #else inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, - const BFloat16*, const BFloat16*[], int, const BFloat16*[], int, - const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&, + const onnxruntime::BFloat16*, const onnxruntime::BFloat16*[], int, + const onnxruntime::BFloat16*[], int, const onnxruntime::BFloat16*, + onnxruntime::BFloat16*[], int, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } @@ -314,12 +313,12 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const float* alpha, const float* A, int lda, - long long int strideA, + int64_t strideA, const float* B, int ldb, - long long int strideB, + int64_t strideB, const float* beta, float* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& prop, bool use_tf32) { 
@@ -349,12 +348,12 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const double* alpha, const double* A, int lda, - long long int strideA, + int64_t strideA, const double* B, int ldb, - long long int strideB, + int64_t strideB, const double* beta, double* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { @@ -376,16 +375,16 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const __half* alpha, const __half* A, int lda, - long long int strideA, + int64_t strideA, const __half* B, int ldb, - long long int strideB, + int64_t strideB, const __half* beta, __half* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { return cublasGemmStridedBatchedEx(handle, @@ -425,16 +424,16 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, int m, int n, int k, const float* alpha, const __half* A, int lda, - long long int strideA, + int64_t strideA, const __half* B, int ldb, - long long int strideB, + int64_t strideB, const float* beta, __half* C, int ldc, - long long int strideC, + int64_t strideC, int batch_count, const cudaDeviceProp& prop, bool /*use_tf32*/) { - const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); + const onnxruntime::cuda::HalfGemmOptions* half_options = onnxruntime::cuda::HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { // The alpha and beta shall have same precision as compute type. 
@@ -472,10 +471,10 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, inline cublasStatus_t cublasGemmStridedBatchedHelper( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const BFloat16* alpha, const BFloat16* A, int lda, - long long int strideA, const BFloat16* B, int ldb, - long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, - long long int strideC, int batch_count, + const onnxruntime::BFloat16* alpha, const onnxruntime::BFloat16* A, int lda, + int64_t strideA, const onnxruntime::BFloat16* B, int ldb, + int64_t strideB, const onnxruntime::BFloat16* beta, onnxruntime::BFloat16* C, int ldc, + int64_t strideC, int batch_count, const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -488,9 +487,9 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper( #else inline cublasStatus_t cublasGemmStridedBatchedHelper( cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, - int, const BFloat16*, const BFloat16*, int, long long int, - const BFloat16*, int, long long int, const BFloat16*, BFloat16*, - int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) { + int, const onnxruntime::BFloat16*, const onnxruntime::BFloat16*, int, int64_t, + const onnxruntime::BFloat16*, int, int64_t, const onnxruntime::BFloat16*, onnxruntime::BFloat16*, + int, int64_t, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -531,4 +530,5 @@ cublasStatus_t cublasCopyHelper( cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); cublasStatus_t cublasCopyHelper( - cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); + cudaStream_t stream, cublasHandle_t handle, int n, const onnxruntime::BFloat16* x, + int incx, onnxruntime::BFloat16* y, int incy); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 27605a6ad8e8..cf8f0a4b2db8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -977,6 +977,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 13, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, {REG_INFO( 18, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 20, ReduceMin, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported)}, {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported)}, {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index cd188761b22f..f45c2b08db94 100644 --- 
a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -434,6 +434,7 @@ namespace OperatorHelper static const int sc_sinceVer_IsNaN = 20; static const int sc_sinceVer_IsInf = 20; static const int sc_sinceVer_ReduceMax = 20; + static const int sc_sinceVer_ReduceMin = 20; } namespace MsftOperatorSet1 diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 097b16ecde53..314e278695c4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -990,6 +991,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); + model_path_ = graph_viewer.ModelPath(); // dump onnx file if environment var is set if (dump_model_ops_) { @@ -1168,7 +1170,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto param_shapes = prog.get_parameter_shapes(); // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map) { + for (auto& [cal_key, cal_val] : dynamic_range_map_) { auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } @@ -1217,7 +1219,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map, + int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; *state = p.release(); @@ -1297,6 +1299,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + cmp_options.set_external_data_path(model_path_.has_parent_path() ? 
model_path_.parent_path().string() : std::filesystem::current_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index f34ca320d0a5..21b582de8f86 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -11,6 +11,7 @@ #include #include +#include namespace onnxruntime { @@ -91,7 +92,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool int8_calibration_cache_available_ = false; bool int8_use_native_migraphx_calibration_table_ = false; std::string calibration_cache_path_; - std::unordered_map dynamic_range_map; + std::unordered_map dynamic_range_map_; bool save_compiled_model_ = false; std::string save_compiled_path_; bool load_compiled_model_ = false; @@ -100,6 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; + mutable std::filesystem::path model_path_; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 1c027e39fa5f..18a6257910a5 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -28,9 +28,8 @@ BackendManager::BackendManager(const GlobalContext& global_context, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, - EPCtxHandler& ctx_handle) { + EPCtxHandler& ep_ctx_handle_) { global_context_ = global_context; - ep_ctx_handle_ = ctx_handle; openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." 
+ std::to_string(global_context_.OpenVINO_Version.at(1)); @@ -129,6 +128,13 @@ BackendManager::BackendManager(const GlobalContext& global_context, #endif } } + if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { + auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph, + logger); + if ((!status.IsOK())) { + ORT_THROW(status); + } + } } // Call EPContext model exporter here if the provider option for exporting @@ -147,13 +153,20 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie std::string model_blob_str; auto compiled_model = concrete_backend_->GetOVCompiledModel(); - auto graph_name = global_context_.onnx_model_path_name; - // Remove extension so we can append suffix to form the complete name of output graph - graph_name = [&]() { - size_t dot = graph_name.find_last_of("."); - if (dot == std::string::npos) return graph_name; - return graph_name.substr(0, dot); - }(); + std::string graph_name = ""; + // Epctx file path from SO is mapped to cache_dir variable for OVEP for readability + if (global_context_.cache_dir != "") { + graph_name = global_context_.cache_dir; + } else { + graph_name = global_context_.onnx_model_path_name; + // Remove extension so we can append suffix to form the complete name of output graph + graph_name = [&]() { + size_t dot = graph_name.find_last_of("."); + if (dot == std::string::npos) return graph_name; + return graph_name.substr(0, dot); + }(); + graph_name = graph_name + "_ctx.onnx"; + } // If embed_mode, then pass on the serialized blob // If not embed_mode, dump the blob here and only pass on the path to the blob if (global_context_.ep_context_embed_mode) { @@ -162,9 +175,19 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie model_blob_str = model_blob_stream.str(); ORT_ENFORCE(model_blob_str.size() != 0); } else { - std::ofstream f(graph_name + ".blob", std::ios::out | std::ios::trunc | std::ios::binary); - compiled_model.export_model(f); - model_blob_str = graph_name + ".blob"; + // Remove extension so we can append suffix to form the complete name of output graph + auto blob_name = [&]() { + size_t dot = graph_name.find_last_of("."); + if (dot == std::string::npos) return graph_name; + return graph_name.substr(0, dot); + }(); + std::ofstream blob_file(blob_name + ".blob", + std::ios::out | std::ios::trunc | std::ios::binary); + if (!blob_file) { + ORT_THROW("Unable to open file for epctx model dump."); + } + compiled_model.export_model(blob_file); + model_blob_str = blob_name + ".blob"; } ORT_RETURN_IF_ERROR(ep_ctx_handle_.ExportEPCtxModel(graph_body_viewer, @@ -172,8 +195,7 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie logger, global_context_.ep_context_embed_mode, model_blob_str, - openvino_sdk_version_, - GetGlobalContext().device_type)); + openvino_sdk_version_)); return Status::OK(); } @@ -248,7 +270,7 @@ static void DumpOpenVINOEPModel(std::string onnx_model_path_name, ONNX_NAMESPACE::ModelProto* model_proto, const onnxruntime::Node& fused_node) { if (openvino_ep::backend_utils::IsDebugEnabled()) { - auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name; + auto model_name = onnx_model_path_name.empty() ? 
"unknown.onnx" : std::move(onnx_model_path_name); #ifdef _WIN32 size_t slash = model_name.find_last_of("\\"); #else diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index f8046bcb3a06..d79aa35be641 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -37,7 +37,7 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, PopulateConfigValue(device_config); // Enable caching - EnableCaching(); + EnableCaching(device_config); // Setting OpenCL queue throttling for GPU EnableGPUThrottling(device_config); @@ -82,26 +82,28 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); } #else // !IO_BUFFER_ENABLED + std::string prec_str = (global_context_.precision_str != "ACCURACY") ? global_context_.precision_str : global_context_.model_precision; if (is_ep_ctx_graph_) { // If the blob is held in an EPContext node, then skip FE+Compile // and directly move on to creating a backend with the executable blob exe_network_ = global_context_.ie_core.ImportModel(ep_ctx_handle.GetModelBlobStream(), hw_target, device_config, + global_context_.ep_context_embed_mode, subgraph_context_.subgraph_name); ie_cnn_network_ = exe_network_.Get().get_runtime_model(); - } else if (!subgraph_context_.has_dynamic_input_shape) { + } else if ((!subgraph_context_.has_dynamic_input_shape) && + ((hw_target.find("AUTO") == std::string::npos) || + (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) > 2))) { + // Optimized OV compile_model API is supported with AUTO from version 2024.3 and above // Inputs with static dimenstions - std::string prec_str = (global_context_.precision_str != "ACCURACY") ? 
global_context_.precision_str : global_context_.model_precision; const std::string model = model_proto.SerializeAsString(); exe_network_ = global_context_.ie_core.CompileModel(model, hw_target, - prec_str, - global_context_.cache_dir, device_config, subgraph_context_.subgraph_name); ie_cnn_network_ = exe_network_.Get().get_runtime_model(); - } else { // Inputs with dynamic dimensions + } else { // For all other types use ov::Model Type ie_cnn_network_ = CreateOVModel(model_proto, global_context_, const_outputs_map_); exe_network_ = global_context_.ie_core.CompileModel( ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); @@ -173,13 +175,19 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } } -void BasicBackend::EnableCaching() { +void BasicBackend::EnableCaching(ov::AnyMap& device_config) { // cache_dir argument has no effect when working with an embed-mode EPContext Graph if (is_ep_ctx_graph_) return; - if (!global_context_.cache_dir.empty()) { + if (!global_context_.cache_dir.empty() && !global_context_.export_ep_ctx_blob) { LOGS_DEFAULT(INFO) << log_tag << "Enables Caching"; - global_context_.ie_core.SetCache(global_context_.cache_dir, global_context_.device_type); + if (global_context_.device_type.find("AUTO:GPU") != std::string::npos) { + std::pair device_property; + device_property = std::make_pair("CACHE_DIR", global_context_.cache_dir); + device_config.emplace(ov::device::properties("GPU", device_property)); + } else { + global_context_.ie_core.SetCache(global_context_.cache_dir); + } } } @@ -274,7 +282,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } try { - infer_request->SetTensor(input_name, tensor_ptr); + infer_request->SetTensor(std::move(input_name), tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 5565223f067b..bcd3161590ba 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -37,7 +37,7 @@ class BasicBackend : public IBackend { void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); - void EnableCaching(); + void EnableCaching(ov::AnyMap& device_config); void EnableGPUThrottling(ov::AnyMap& device_config); void EnableStreams(); void SetNumThreads(ov::AnyMap& device_config); diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index cd1ae6150e1d..e2df9c83f15a 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -19,8 +19,7 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, const logging::Logger& logger, const bool& ep_context_embed_mode, const std::string& model_blob_str, - const std::string& openvino_sdk_version, - const std::string& device_type) const { + const std::string& openvino_sdk_version) const { auto model_build = graph_viewer.CreateModel(logger); auto& graph_build = model_build->MainGraph(); @@ -77,9 +76,12 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); // Finally, dump the model - std::ofstream dump(graph_name + "-ov_" + 
device_type + "_blob.onnx", - std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(dump); + std::ofstream epctx_onnx_model(graph_name, + std::ios::out | std::ios::trunc | std::ios::binary); + if (!epctx_onnx_model) { + ORT_THROW("Unable to create epctx onnx model file "); + } + model_proto->SerializeToOstream(epctx_onnx_model); LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Export blob as EPContext Node"; @@ -90,9 +92,7 @@ Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer) { auto node = graph_viewer.GetNode(0); auto& attrs = node->GetAttributes(); ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) > 0); - model_stream_ = std::make_shared(attrs.at(EP_CACHE_CONTEXT).s()); - LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; is_valid_ep_ctx_graph_ = true; diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index b2b9b5bc53d4..610e9fd49c90 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -29,8 +29,7 @@ class EPCtxHandler { const logging::Logger& logger, const bool& ep_context_embed_mode, const std::string& model_blob_str, - const std::string& openvino_sdk_version, - const std::string& device_type) const; + const std::string& openvino_sdk_version) const; Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer); bool CheckForOVEPCtxNode(const GraphViewer& graph_viewer, std::string openvino_sdk_version) const; bool IsValidOVEPCtxGraph() const { return is_valid_ep_ctx_graph_; } diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 655e1b180388..29c45916795d 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -34,6 +34,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv global_context_->export_ep_ctx_blob = info.export_ep_ctx_blob_; global_context_->enable_qdq_optimizer = info.enable_qdq_optimizer_; global_context_->disable_cpu_fallback = info.disable_cpu_fallback_; + global_context_->ep_context_embed_mode = info.so_epctx_embed_mode_; // to check if target device is available // using ie_core capability GetAvailableDevices to fetch list of devices plugged in @@ -47,7 +48,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv info.device_type_.find("AUTO") != std::string::npos) { device_found = true; } else { - for (std::string device : available_devices) { + for (const std::string& device : available_devices) { if (device.rfind(info.device_type_, 0) == 0) { if (info.device_type_.find("GPU") != std::string::npos && (info.precision_ == "FP32" || info.precision_ == "FP16" || @@ -146,11 +147,6 @@ common::Status OpenVINOExecutionProvider::Compile( *GetLogger(), ep_ctx_handle_); - if (global_context_->export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { - ORT_RETURN_IF_ERROR(backend_manager->ExportCompiledBlobAsEPCtxNode(graph_body_viewer, - *GetLogger())); - } - compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 
050fb91c5177..030e5bba71b6 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -16,16 +16,23 @@ namespace onnxruntime { +struct OVDevices { + ov::Core core; + std::vector get_ov_devices() const { + return core.get_available_devices(); + } +}; + static void print_build_options() { std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl; std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority " << "you want to build" << std::endl; std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build " - << "are ['CPU','GPU','NPU']" + << "are ['CPU','GPU','NPU','GPU.x'] where x = 0,1,2 and so on" << std::endl; std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. " - << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" + << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU Ex: AUTO:GPU.0,CPU Ex: AUTO:GPU.1,CPU" << std::endl; } @@ -40,7 +47,8 @@ static std::vector split(const std::string& s, char delim) { return result; } -static std::vector parseDevices(const std::string& device_string) { +static std::vector parseDevices(const std::string& device_string, + const std::vector& available_devices) { std::string comma_separated_devices = device_string; if (comma_separated_devices.find(":") != std::string::npos) { comma_separated_devices = comma_separated_devices.substr(comma_separated_devices.find(":") + 1); @@ -50,8 +58,15 @@ static std::vector parseDevices(const std::string& device_string) { print_build_options(); ORT_THROW("Invalid device string: " + device_string); } - std::vector dev_options = {"CPU", "GPU", "NPU"}; - for (std::string dev : devices) { + std::set dev_options = {"CPU", "GPU", "NPU"}; + + for (auto& device : available_devices) { + if (dev_options.find(device) == dev_options.end()) { + auto dev_options_update = dev_options.emplace(device); + } + } + + for (const std::string& dev : devices) { if (!std::count(dev_options.begin(), dev_options.end(), dev)) { print_build_options(); ORT_THROW("Invalid device string: " + device_string); @@ -75,28 +90,42 @@ struct OpenVINOExecutionProviderInfo { bool export_ep_ctx_blob_{false}; bool enable_qdq_optimizer_{false}; bool disable_cpu_fallback_{false}; + bool so_epctx_embed_mode_{true}; OpenVINOExecutionProviderInfo() = delete; - explicit OpenVINOExecutionProviderInfo(std::string dev_type, std::string precision, bool enable_npu_fast_compile, - size_t num_of_threads, std::string cache_dir, std::string model_priority, + explicit OpenVINOExecutionProviderInfo(const std::string& dev_type, const std::string& precision, + bool enable_npu_fast_compile, size_t num_of_threads, + const std::string& cache_dir, const std::string& model_priority, int num_streams, void* context, bool enable_opencl_throttling, bool disable_dynamic_shapes, bool export_ep_ctx_blob, - bool enable_qdq_optimizer, bool disable_cpu_fallback) - : precision_(precision), + bool enable_qdq_optimizer, bool disable_cpu_fallback, + bool so_epctx_embed_mode) + : precision_(std::move(precision)), enable_npu_fast_compile_(enable_npu_fast_compile), num_of_threads_(num_of_threads), cache_dir_(std::move(cache_dir)), - model_priority_(model_priority), + model_priority_(std::move(model_priority)), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), disable_dynamic_shapes_(disable_dynamic_shapes), 
export_ep_ctx_blob_(export_ep_ctx_blob), enable_qdq_optimizer_(enable_qdq_optimizer), - disable_cpu_fallback_(disable_cpu_fallback) { + disable_cpu_fallback_(disable_cpu_fallback), + so_epctx_embed_mode_{so_epctx_embed_mode} { std::set ov_supported_device_types = {"CPU", "GPU", "GPU.0", "GPU.1", "NPU"}; + + OVDevices devices; + std::vector available_devices = devices.get_ov_devices(); + + for (auto& device : available_devices) { + if (ov_supported_device_types.find(device) == ov_supported_device_types.end()) { + ov_supported_device_types.emplace(device); + } + } + if (dev_type == "") { LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" << "No runtime device selection option provided."; @@ -116,7 +145,7 @@ struct OpenVINOExecutionProviderInfo { dev_type = DEVICE; if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) { - std::vector devices = parseDevices(dev_type); + std::vector devices = parseDevices(dev_type, available_devices); precision_ = "FP16"; if (devices[0] == "CPU") { precision_ = "FP32"; @@ -127,7 +156,7 @@ struct OpenVINOExecutionProviderInfo { } else if (ov_supported_device_types.find(dev_type) != ov_supported_device_types.end()) { device_type_ = std::move(dev_type); } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) { - std::vector devices = parseDevices(dev_type); + std::vector devices = parseDevices(dev_type, available_devices); device_type_ = dev_type; } else { ORT_THROW("Invalid device string: " + dev_type); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 45bba431741c..3738f2a53415 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -14,7 +14,8 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { int num_streams, void* context, bool enable_opencl_throttling, bool disable_dynamic_shapes, bool export_ep_ctx_blob, bool enable_qdq_optimizer, - bool disable_cpu_fallback) + bool disable_cpu_fallback, + bool so_epctx_embed_mode) : precision_(precision), enable_npu_fast_compile_(enable_npu_fast_compile), num_of_threads_(num_of_threads), @@ -25,10 +26,12 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { disable_dynamic_shapes_(disable_dynamic_shapes), export_ep_ctx_blob_(export_ep_ctx_blob), enable_qdq_optimizer_(enable_qdq_optimizer), - disable_cpu_fallback_(disable_cpu_fallback) { + disable_cpu_fallback_(disable_cpu_fallback), + so_epctx_embed_mode_(so_epctx_embed_mode) { device_type_ = (device_type == nullptr) ? "" : device_type; cache_dir_ = (cache_dir == nullptr) ? 
"" : cache_dir; } + ~OpenVINOProviderFactory() override { } @@ -48,13 +51,15 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { bool export_ep_ctx_blob_; bool enable_qdq_optimizer_; bool disable_cpu_fallback_; + bool so_epctx_embed_mode_; }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { OpenVINOExecutionProviderInfo info(device_type_, precision_, enable_npu_fast_compile_, num_of_threads_, cache_dir_, model_priority_, num_streams_, context_, enable_opencl_throttling_, disable_dynamic_shapes_, export_ep_ctx_blob_, enable_qdq_optimizer_, - disable_cpu_fallback_); + disable_cpu_fallback_, + so_epctx_embed_mode_); return std::make_unique(info); } @@ -105,6 +110,8 @@ struct OpenVINO_Provider : Provider { bool disable_cpu_fallback = false; + bool so_epctx_embed_mode = true; + if (provider_options_map.find("device_type") != provider_options_map.end()) { device_type = provider_options_map.at("device_type").c_str(); @@ -113,6 +120,14 @@ struct OpenVINO_Provider : Provider { std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", "GPU.0_FP16", "GPU.1_FP16"}; + OVDevices devices; + std::vector available_devices = devices.get_ov_devices(); + + for (auto& device : available_devices) { + if (ov_supported_device_types.find(device) == ov_supported_device_types.end()) { + ov_supported_device_types.emplace(device); + } + } if (deprecated_device_types.find(device_type) != deprecated_device_types.end()) { std::string deprecated_device = device_type; int delimit = device_type.find("_"); @@ -128,8 +143,8 @@ struct OpenVINO_Provider : Provider { (device_type.find("MULTI:") == 0) || (device_type.find("AUTO:") == 0))) { ORT_THROW( - "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " - "Select from 'CPU', 'GPU', 'GPU.0', 'GPU.1', 'NPU' or from" + "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " + "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" " HETERO/MULTI/AUTO options available. \n"); } } @@ -177,6 +192,10 @@ struct OpenVINO_Provider : Provider { } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { + if (!std::all_of(provider_options_map.at("num_of_threads").begin(), + provider_options_map.at("num_of_threads").end(), ::isdigit)) { + ORT_THROW("[ERROR] [OpenVINO-EP] Number of threads should be a number. 
\n"); + } num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; @@ -253,9 +272,8 @@ struct OpenVINO_Provider : Provider { } } } - - if (provider_options_map.find("export_ep_ctx_blob") != provider_options_map.end()) { - bool_flag = provider_options_map.at("export_ep_ctx_blob"); + if (provider_options_map.find("so_export_ep_ctx_blob") != provider_options_map.end()) { + bool_flag = provider_options_map.at("so_export_ep_ctx_blob"); if (bool_flag == "true" || bool_flag == "True") export_ep_ctx_blob = true; else if (bool_flag == "false" || bool_flag == "False") @@ -271,6 +289,37 @@ struct OpenVINO_Provider : Provider { disable_cpu_fallback = false; bool_flag = ""; } + if (provider_options_map.find("so_epctx_embed_mode") != provider_options_map.end()) { + bool_flag = provider_options_map.at("so_epctx_embed_mode"); + if (bool_flag == "true" || bool_flag == "True") + so_epctx_embed_mode = true; + else if (bool_flag == "false" || bool_flag == "False") + so_epctx_embed_mode = false; + bool_flag = ""; + } + + if (provider_options_map.find("so_epctx_path") != provider_options_map.end()) { + // The path to dump epctx model is valid only when epctx is enabled. + // Overrides the cache_dir option to dump model cache files from OV. + if (export_ep_ctx_blob) { + auto ep_context_file_path_ = provider_options_map.at("so_epctx_path"); + auto file_path = std::filesystem::path(ep_context_file_path_); + // ep_context_file_path_ file extension must be .onnx + if (!ep_context_file_path_.empty() && + file_path.extension().generic_string() == ".onnx") { + // ep_context_file_path_ must be provided as a directory, create it if doesn't exist + auto parent_path = file_path.parent_path(); + if (!std::filesystem::is_directory(parent_path) && + !std::filesystem::create_directory(parent_path)) { + ORT_THROW("[ERROR] [OpenVINO] Failed to create directory : " + file_path.parent_path().generic_string() + " \n"); + } + cache_dir = ep_context_file_path_.c_str(); + } else { + ORT_THROW("[ERROR] [OpenVINO] Invalid ep_ctx_file_path" + ep_context_file_path_ + " \n"); + } + } + } + return std::make_shared(const_cast(device_type.c_str()), const_cast(precision.c_str()), enable_npu_fast_compile, @@ -283,7 +332,8 @@ struct OpenVINO_Provider : Provider { disable_dynamic_shapes, export_ep_ctx_blob, enable_qdq_optimizer, - disable_cpu_fallback); + disable_cpu_fallback, + so_epctx_embed_mode); } void Initialize() override { diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 8dd00857b7dd..7e8681d304ab 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -63,7 +63,6 @@ std::shared_ptr OVCore::ReadModel(const std::string& model, const std return FE->convert(inputModel); } else { ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); - return NULL; } } catch (const Exception& e) { ORT_THROW(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what())); @@ -73,9 +72,9 @@ std::shared_ptr OVCore::ReadModel(const std::string& model, const std } OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_network, - std::string hw_target, - const ov::AnyMap& device_config, - std::string name) { + std::string& hw_target, + ov::AnyMap& device_config, + const std::string& name) { ov::CompiledModel obj; try { obj = oe.compile_model(ie_cnn_network, hw_target, device_config); @@ -92,22 +91,12 @@ OVExeNetwork 
OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo } OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, - std::string hw_target, - std::string precision, - std::string cache_dir, - const ov::AnyMap& device_config, - std::string name) { + std::string& hw_target, + ov::AnyMap& device_config, + const std::string& name) { ov::CompiledModel obj; try { - if (hw_target == "AUTO:GPU,CPU") { - obj = oe.compile_model(onnx_model, ov::Tensor(), - "AUTO", - ov::device::priorities("GPU", "CPU"), - ov::device::properties("GPU", {ov::cache_dir(cache_dir), - ov::hint::inference_precision(precision)})); - } else { - obj = oe.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); - } + obj = oe.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); #ifndef NDEBUG printDebugInfo(obj); #endif @@ -123,9 +112,19 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, std::string hw_target, const ov::AnyMap& device_config, + bool embed_mode, std::string name) { try { - auto obj = oe.import_model(*model_stream, hw_target, device_config); + ov::CompiledModel obj; + if (embed_mode) { + obj = oe.import_model(*model_stream, hw_target, device_config); + } else { + std::string blob_file_path = (*model_stream).str(); + std::ifstream modelStream(blob_file_path, std::ios_base::binary | std::ios_base::in); + obj = oe.import_model(modelStream, + hw_target, + {}); + } #ifndef NDEBUG printDebugInfo(obj); #endif @@ -138,10 +137,8 @@ OVExeNetwork OVCore::ImportModel(std::shared_ptr model_strea } } -void OVCore::SetCache(std::string cache_dir_path, std::string device_type) { - if (device_type != "AUTO:GPU,CPU") { - oe.set_property(ov::cache_dir(cache_dir_path)); - } +void OVCore::SetCache(const std::string& cache_dir_path) { + oe.set_property(ov::cache_dir(cache_dir_path)); } #ifdef IO_BUFFER_ENABLED diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index af6f252feb2c..fa22e0f3cb03 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -40,20 +40,23 @@ class OVCore { ov::Core oe; public: + // OV Interface For Reading Model std::shared_ptr ReadModel(const std::string& model_stream, const std::string& model_path) const; + // OV Interface for Compiling OV Model Type OVExeNetwork CompileModel(std::shared_ptr& ie_cnn_network, - std::string hw_target, - const ov::AnyMap& device_config, - std::string name); + std::string& hw_target, + ov::AnyMap& device_config, + const std::string& name); + // OV Interface for Fast Compile OVExeNetwork CompileModel(const std::string& onnx_model, - std::string hw_target, - std::string precision, - std::string cache_dir, - const ov::AnyMap& device_config, - std::string name); + std::string& hw_target, + ov::AnyMap& device_config, + const std::string& name); + // OV Interface for Import model Stream OVExeNetwork ImportModel(std::shared_ptr model_stream, std::string hw_target, const ov::AnyMap& device_config, + bool embed_mode, std::string name); #ifdef IO_BUFFER_ENABLED OVExeNetwork CompileModel(std::shared_ptr& model, @@ -64,7 +67,7 @@ class OVCore { std::string name); #endif std::vector GetAvailableDevices(); - void SetCache(std::string cache_dir_path, std::string device_type); + void SetCache(const std::string& cache_dir_path); ov::Core& Get() { return oe; } void SetStreams(const std::string& device_type, int num_streams); }; diff --git 
a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 856b97a0896d..3fcaff4369c8 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -35,18 +35,16 @@ GetCapability::GetCapability(const GraphViewer& graph_viewer_param, device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; } -#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1 - data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2 - data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3 - data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0 +#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0 data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1 data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 2 + data_ops_ = new DataOps(graph_viewer_, V_2024_2, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 3 + data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled); #endif } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 38c029faff9d..d9aa13ec1bba 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -142,6 +142,7 @@ std::vector supported_op_mode = { {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}}, {"GridSample", V_2022_3, {"CPU"}}, {"GridSample", V_2023_0, {"GPU"}}, + {"GRU", V_2024_1, {"CPU", "GPU"}}, {"HardMax", V_2023_1, {"CPU", "GPU"}}, {"Identity", V_2020_4, {"CPU", "GPU"}}, {"If", V_2022_3, {"CPU", "GPU"}}, @@ -155,6 +156,7 @@ std::vector supported_op_mode = { {"LessOrEqual", V_2022_1, {"CPU", "GPU"}}, {"Log", V_2020_4, {"CPU", "GPU"}}, {"LogSoftMax", V_2022_1, {"CPU", "GPU"}}, + {"LogSoftmax", V_2024_1, {"CPU", "GPU"}}, {"Loop", V_2021_4, {"CPU", "GPU"}}, {"LpNormalization", V_2023_1, {"CPU", "GPU"}}, {"LRN", V_2020_4, {"CPU", "GPU"}}, @@ -361,7 +363,7 @@ void DataOps::populate_op_mode_supported() { // populate unsupportedmode_t { - UnsupportedOpMode obj = {{V_2024_1}, + UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3}, [this](const Node* node, const InitializedTensorSet&) { // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { @@ -376,7 +378,7 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"ReduceMax", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3}, [this](const Node* node, const 
InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); @@ -393,7 +395,7 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Reshape", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. @@ -408,7 +410,7 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index 7cfb0516b8cc..4c064b08405c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -28,7 +28,9 @@ enum versionNum { V_2023_2, V_2023_3, V_2024_0, - V_2024_1 + V_2024_1, + V_2024_2, + V_2024_3 }; using VersionNum = enum versionNum; diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index c7689a0be7e7..a2b3ed068235 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -205,11 +205,11 @@ static bool IsConnectedQAConstantInitializer(const Node* dq_node, const onnxrunt // Check required because in some cases, when a NodeUnit cannot be formed with this standalone DQ // we still need to check if it feeds into a supported Op -static bool DQFeedsASupportedOp(const Node* dq_node, const onnxruntime::GraphViewer& src_graph) { +static bool DQFeedsASupportedOp(const Node* dq_node) { if (!dq_node->GetOutputEdgesCount()) return false; // Only feeds the graph output, and not any node const auto& target_node = *dq_node->OutputNodesBegin(); - const auto op_type = target_node.OpType(); + const auto& op_type = target_node.OpType(); if (op_type == "Conv" || op_type == "MatMul") { // Conv and MatMul always keeps int8 DQs except if the DQ is sandwiched between Softmax and Conv/MatMul @@ -219,8 +219,8 @@ static bool DQFeedsASupportedOp(const Node* dq_node, const onnxruntime::GraphVie return true; } } else if (op_type == "Add") { - // Add keeps all DQs except if it has const inits - return !IsAnyDQAConstantInitializer(&target_node, src_graph); + // Add => keeps all DQs + return true; } return false; } @@ -291,7 +291,7 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit, const onnxruntime::GraphViewer& src_graph, SkipReason& reason) { const auto& target_node = node_unit.GetNode(); - auto op_type = node_unit.OpType(); + const auto& op_type = node_unit.OpType(); // #1 Reverse DQ duplication if (dq_node->Name().find(DuplicateDQ) != std::string::npos) { @@ -337,6 +337,18 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit, } } +static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit, + const std::unordered_map graph_op_data_type) { + auto op_of_quantized_layer = node_unit.Outputs(); + for (auto& itr : op_of_quantized_layer) { + auto 
it = graph_op_data_type.find(itr.node_arg.Name()); + if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") { + return true; + } + } + return false; +} + static bool CheckQRuleSet(const NodeUnit& node_unit, const Node* q_node, const onnxruntime::GraphViewer& src_graph, @@ -345,7 +357,13 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, // This Q should also be uint8 const auto& target_node = node_unit.GetNode(); - auto op_type = node_unit.OpType(); + const auto& op_type = node_unit.OpType(); + + auto op = src_graph.GetOutputs(); + std::unordered_map graph_op_data_type; + for (auto& ops : op) { + graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data(); + } // If UInt16 Q, don't keep it if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) { @@ -359,6 +377,8 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, } else if (op_type == "Add") { // Add keeps all Qs return true; + } else if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) { + return true; } else { // Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false); @@ -469,7 +489,7 @@ static void AddStandaloneNodeUnit(onnxruntime::Graph& dst_graph, const onnxrunti add_identity_op(true); else if (IsConnectedQPresent(src_graph, dst_graph.Nodes(), &node_unit.GetNode(), node_unit.GetNode().InputDefs())) AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode()); - else if (DQFeedsASupportedOp(&node_unit.GetNode(), src_graph)) + else if (DQFeedsASupportedOp(&node_unit.GetNode())) AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode()); else add_identity_op(false); @@ -543,7 +563,7 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph, // Add Node args for inputs for (const auto& node_unit_input : node_unit_inputs) { - auto node_arg_name = node_unit_input.node_arg.Name(); + const auto& node_arg_name = node_unit_input.node_arg.Name(); if (auto dq_node_arg = dq_node_args_to_keep.find(node_arg_name); dq_node_arg != dq_node_args_to_keep.end()) { // Add supported DQ as an input arg for the target node input_args.push_back(dq_node_arg->second); diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index c45f5cd0848d..83c08f3dbd25 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -88,8 +88,6 @@ It is required to ensure we do not break up a QDQ node unit during partitioning. @param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with. @param is_node_supported_fn Callback to check whether a node is supported. @param on_group_closed_fn Callback to indicate a completed partition node group. -@param debug_output Print diagnostic output about the partitions and reasons for partition breaks. - No-op in a release build. @return The partition node groups. 
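The CheckQFeedsIntoQuantizedOutput helper added above keeps a QuantizeLinear node whenever one of its node-unit outputs is also a uint8 graph output. A minimal standalone sketch of that lookup, with plain standard-library types standing in for the ORT NodeUnit/GraphViewer classes (OutputDef and the map contents are illustrative assumptions, not the real API):

#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a node unit's output definitions.
struct OutputDef { std::string name; };

// Returns true when any output of the quantized node unit is also a graph output
// whose declared type is tensor(uint8), mirroring CheckQFeedsIntoQuantizedOutput.
bool FeedsQuantizedGraphOutput(const std::vector<OutputDef>& node_unit_outputs,
                               const std::unordered_map<std::string, std::string>& graph_output_types) {
  for (const auto& out : node_unit_outputs) {
    auto it = graph_output_types.find(out.name);
    if (it != graph_output_types.end() && it->second == "tensor(uint8)") {
      return true;
    }
  }
  return false;
}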
*/ std::vector> CreateSupportedPartitionNodeGroups( @@ -97,12 +95,7 @@ std::vector> CreateSupportedPartitionNodeGroups( const IsNodeSupportedFn& is_node_supported_fn, const OnGroupClosedFn& on_group_closed_fn, const std::string& execution_provider_type, - const std::unordered_map* node_unit_map, - bool debug_output) { -#ifdef NDEBUG - ORT_UNUSED_PARAMETER(debug_output); -#endif - + const std::unordered_map* node_unit_map) { ORT_ENFORCE(is_node_supported_fn, "Node support test is required."); /* @@ -146,12 +139,10 @@ std::vector> CreateSupportedPartitionNodeGroups( auto close_group = [&]() { if (!supported_group.empty()) { #ifndef NDEBUG - if (debug_output) { - LOGS_DEFAULT(VERBOSE) << "New partition node group.\n" - << "Unsupported nodes on group border: " - << NodeGroupDebugString(nodes_to_process_with_next_group, true) << "\n" - << "Nodes in group: " << NodeGroupDebugString(supported_group); - } + LOGS_DEFAULT(VERBOSE) << "New partition node group.\n" + << "Unsupported nodes on group border: " + << NodeGroupDebugString(nodes_to_process_with_next_group, true) << "\n" + << "Nodes in group: " << NodeGroupDebugString(supported_group); #endif // if no on_group_closed_fn callback was given, keep the partition @@ -163,7 +154,7 @@ std::vector> CreateSupportedPartitionNodeGroups( } #ifndef NDEBUG else { - LOGS_DEFAULT_IF(debug_output, VERBOSE) << "Discarded partition node group."; + LOGS_DEFAULT(VERBOSE) << "Discarded partition node group."; } #endif @@ -291,7 +282,8 @@ InlinedHashSet CreateExcludedNodeSet(const GraphViewer& graph_viewe std::unique_ptr MakeComputeCapability(const GraphViewer& graph_viewer, const std::vector& group, const GenerateMetadefNameFn& generate_metadef_name, - const std::string& execution_provider_name) { + const std::string& execution_provider_name, + bool drop_constant_initializers) { std::unordered_set node_set; node_set.reserve(group.size()); node_set.insert(group.cbegin(), group.cend()); @@ -354,6 +346,10 @@ std::unique_ptr MakeComputeCapability(const GraphViewer& grap meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; for (const auto& input : ordered_subgraph_inputs) { + if (drop_constant_initializers && graph_viewer.IsConstantInitializer(input->Name(), true)) { + continue; + } + meta_def->inputs.push_back(input->Name()); } @@ -374,13 +370,12 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::string& execution_provider_name, const std::string& execution_provider_type, const std::unordered_map* node_unit_map, - bool debug_output) { + bool drop_constant_initializers) { const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer, is_node_supported_fn, on_partition_closed_fn, execution_provider_type, - node_unit_map, - debug_output); + node_unit_map); std::vector> partitions{}; partitions.reserve(groups.size()); @@ -390,7 +385,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, std::back_inserter(partitions), [&](const auto& supported_partition) { return MakeComputeCapability(graph_viewer, supported_partition, generate_metadef_name_fn, - execution_provider_name); + execution_provider_name, drop_constant_initializers); }); return partitions; @@ -404,7 +399,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::string& execution_provider_name, const std::string& execution_provider_type, const std::unordered_map* node_unit_map, - bool debug_output) { + bool drop_constant_initializers) { const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops); const bool check_excluded_nodes = 
!excluded_nodes.empty(); @@ -419,7 +414,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, execution_provider_name, execution_provider_type, node_unit_map, - debug_output); + drop_constant_initializers); } } // namespace utils diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h index c3f6b104e3f6..235a88cfdb8a 100644 --- a/onnxruntime/core/providers/partitioning_utils.h +++ b/onnxruntime/core/providers/partitioning_utils.h @@ -62,9 +62,10 @@ Create the supported partitions for the execution provider. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. @param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. Should be created by EP calling GetAllNodeUnits. -@param debug_output Print diagnostic output about the partitions and reasons for partition breaks. - No-op in a release build. - +@param drop_constant_initializer Drop constant initializers from input to a ComputeCapability. + Set to true if constant initializers have been copied into a compiled model to allow + ORT to free the initializer. If the initializer remains as an input it will appear to + still be in-use. @returns ComputeCapability instances for all partitions assigned to the execution provider. */ std::vector> @@ -74,8 +75,8 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, - const std::unordered_map* node_unit_map = nullptr, - bool debug_output = false); + const std::unordered_map* node_unit_map, + bool drop_constant_initializers = false); /** Create the supported partitions for the execution provider. @@ -88,9 +89,10 @@ Create the supported partitions for the execution provider. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. @param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. Should be created by EP calling GetAllNodeUnits. -@param debug_output Print diagnostic output about the partitions and reasons for partition breaks. - No-op in a release build. - +@param drop_constant_initializer Drop constant initializers from input to a ComputeCapability. + Set to true if constant initializers have been copied into a compiled model to allow + ORT to free the initializer. If the initializer remains as an input it will appear to + still be in-use. @returns ComputeCapability instances for all partitions assigned to the execution provider. */ std::vector> @@ -100,8 +102,8 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name, const std::string& execution_provider_name, const std::string& execution_provider_type, - const std::unordered_map* node_unit_map = nullptr, - bool debug_output = false); + const std::unordered_map* node_unit_map, + bool drop_constant_initializers = false); /** Create a ComputeCapability instance from the group of nodes. @@ -120,7 +122,8 @@ Will automatically determine the inputs and outputs required. 
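The new drop_constant_initializers parameter documented above changes which node args end up as ComputeCapability inputs. A rough sketch of that filtering, assuming the caller already knows which input names are constant initializers (the container types here are simplifications, not the ORT API):

#include <string>
#include <unordered_set>
#include <vector>

// Keep only the subgraph inputs that should remain visible to ORT. When the EP has
// copied constant initializers into its compiled blob, dropping them here lets ORT
// free the weights instead of treating them as still in use.
std::vector<std::string> FilterMetaDefInputs(const std::vector<std::string>& subgraph_inputs,
                                             const std::unordered_set<std::string>& constant_initializers,
                                             bool drop_constant_initializers) {
  std::vector<std::string> kept;
  for (const auto& name : subgraph_inputs) {
    if (drop_constant_initializers && constant_initializers.count(name) > 0) {
      continue;  // weight is already baked into the compiled model
    }
    kept.push_back(name);
  }
  return kept;
}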
std::unique_ptr MakeComputeCapability(const GraphViewer& graph_viewer, const std::vector& group, const GenerateMetadefNameFn& generate_metadef_name, - const std::string& execution_provider_name); + const std::string& execution_provider_name, + bool drop_constant_initializers); /** Create the set of nodes to exclude based on a set of stop ops. diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 16a058854a74..07abcf1c7bf8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -392,15 +392,23 @@ class BatchNormOpBuilder : public BaseOpBuilder { const double rmin, QnnQuantParamsWrapper& quant_param, std::vector& raw_tensor) const { + bool symmetric = false; if (info.quant_param.IsQuantized()) { - raw_tensor.resize(double_tensor.size()); + size_t data_size = double_tensor.size(); + // QNN BatchNorm int32 bias requires symmetric quantizated + if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { + data_size *= sizeof(int32_t); + symmetric = true; + } + raw_tensor.resize(data_size); float scale = 0.0f; - int zero_point = 0; + int32_t zero_point = 0; ORT_RETURN_IF_ERROR(utils::GetQuantParams(static_cast(rmin), static_cast(rmax), info.qnn_data_type, scale, - zero_point)); + zero_point, + symmetric)); quant_param = QnnQuantParamsWrapper(scale, zero_point); for (size_t i = 0; i < double_tensor.size(); ++i) { // onnx only supports 8 bits quantization @@ -411,6 +419,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { int8_t quant_value = static_cast(quant_value_int); raw_tensor[i] = *reinterpret_cast(&quant_value); + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { + int32_t quant_value = static_cast(quant_value_int); + size_t pos = i * sizeof(int32_t); + std::memcpy(&raw_tensor[pos], reinterpret_cast(&quant_value), sizeof(int32_t)); } else { // TODO(adrianlizarraga): Should support 16-bit quantization as well. 
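The BatchNorm change above switches the int32 bias path to symmetric quantization and packs each 4-byte value into the raw tensor with memcpy. A compact sketch of that packing, under the simplifying assumption that the scale is already known (the real builder derives it through utils::GetQuantParams with the symmetric flag set):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Symmetric quantization uses a zero-point of 0, so q = round(v / scale). Each int32
// result occupies sizeof(int32_t) bytes of the raw buffer, matching the resize to
// double_tensor.size() * sizeof(int32_t) in the code above.
std::vector<uint8_t> QuantizeBiasInt32Symmetric(const std::vector<double>& values, float scale) {
  std::vector<uint8_t> raw(values.size() * sizeof(int32_t));
  for (std::size_t i = 0; i < values.size(); ++i) {
    const int32_t q = static_cast<int32_t>(std::lround(values[i] / scale));
    std::memcpy(&raw[i * sizeof(int32_t)], &q, sizeof(int32_t));
  }
  return raw;
}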
ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); @@ -444,8 +456,7 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0."); const size_t input_rank = input_shape.size(); - ORT_RETURN_IF(input_rank <= 2 || input_rank > 4, - "QNN BatchNorm only supports input ranks of size 3 or 4."); + ORT_RETURN_IF(input_rank > 4, "QNN BatchNorm only supports input ranks of size <= 4."); const uint32_t num_channels = input_shape[1]; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index d0f6ce9effd9..64f676aaa987 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -79,7 +79,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (is_quantized_tensor) { ORT_RETURN_IF_ERROR(utils::GetQnnDataType(true, type_proto, qnn_data_type)); float scale = 0.0f; - int zero_point = 0; + int32_t zero_point = 0; float rmax = 1.0f; float rmin = 1.0f; ORT_RETURN_IF_ERROR(utils::GetQuantParams(rmin, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index c667aeeaa61f..a31b15948cb7 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,10 +87,10 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && QNN_API_VERSION_MINOR == 17 +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18) if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24 (QNN API version 2.17) has a validation bug for implicit bias inputs, so provide - // an explicit bias of all 0 (quantized int32). + // Bias is implicit. QNN SDK 2.24/2.25 (QNN API version 2.17/2.18) has a validation bug for implicit bias inputs, + // so provide an explicit bias of all 0 (quantized int32). TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc deleted file mode 100644 index b04075f11203..000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
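The LayerNorm workaround above supplies an explicit all-zero bias when the QNN SDK version cannot validate an implicit one. The observation that makes this cheap: with symmetric int32 quantization (zero-point 0), a real value of 0 always maps to the integer 0, so the buffer is simply zero-filled with one element per channel. A tiny illustrative helper (the name and shape handling are assumptions, not the EP's code):

#include <cstddef>
#include <cstdint>
#include <vector>

// One zero per output channel; any per-tensor scale is valid for an all-zero bias.
std::vector<int32_t> MakeExplicitZeroBias(std::size_t num_output_channels) {
  return std::vector<int32_t>(num_output_channels, 0);
}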
- -#include "core/providers/qnn/builder/qnn_fusions.h" - -#include -#include -#include -#include -#include -#include -#include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" - -#define QNN_RETURN_OK_IF_ERROR(expr, logger) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - LOGS((logger), VERBOSE) << _status.ErrorMessage(); \ - return Status::OK(); \ - } \ - } while (0) - -namespace onnxruntime { -namespace qnn { - -/** - * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from - * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param start_node_unit The node unit that could potentially start the DQ -> Q sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return An onnxruntime::Status - */ -static Status TryHandleConvertSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - - // Looking for a standalone DQ to start the sequence. - if (start_node_unit.OpType() != QDQ::DQOpName || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - const Node& dq_node = start_node_unit.GetNode(); - - // DQ must have a single Q child. DQ must not produce a graph output. - auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); - if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { - return Status::OK(); - } - - const Node& q_node = *children[0]; - const auto q_node_unit_it = node_unit_map.find(&q_node); - - ORT_RETURN_IF(q_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - - const NodeUnit* q_node_unit = q_node_unit_it->second; - - // Check if Q node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(q_node_unit) != 0) { - return Status::OK(); - } - - // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { - return graph_viewer.GetConstantInitializer(initializer_name, true); - }; - - // DQ and Q must have equal scale type and different zp type. 
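The (now removed) TryHandleConvertSequence above only fuses DQ -> Q into a QNN Convert when the DQ is standalone, has exactly one consumer which is a Q, does not produce a graph output, and the pair converts between quantization types with the same scale type but different zero-point types. Those gating conditions can be summarized in a small predicate over a simplified candidate description (DqQCandidate is a hypothetical stand-in for the ORT Node/NodeUnit queries):

#include <cstddef>
#include <string>

struct DqQCandidate {
  std::size_t dq_output_edges = 0;   // total consumers of the DQ's output
  std::size_t q_children = 0;        // consumers that are QuantizeLinear nodes
  bool dq_produces_graph_output = false;
  std::string dq_scale_type;         // e.g. "float"
  std::string q_scale_type;
  std::string dq_zero_point_type;    // e.g. "uint8"
  std::string q_zero_point_type;     // e.g. "uint16"
};

// Mirrors the checks above: a lone DQ feeding exactly one Q, where only the
// zero-point (activation) type changes.
bool IsConvertFusionCandidate(const DqQCandidate& c) {
  if (c.q_children != 1 || c.dq_output_edges != 1 || c.dq_produces_graph_output) return false;
  if (c.dq_scale_type != c.q_scale_type) return false;  // scale types must match
  return c.dq_zero_point_type != c.q_zero_point_type;   // zero-point types must differ
}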
- if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = q_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN Convert via fusion. dq_node name: [" << dq_node.Name() - << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit->Name() - << "] q_node optype: [" << q_node_unit->OpType() - << "]"; - - // Add a QNN Convert to the model. Get the input from the DQ node, and the output from the Q node. - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(*q_node_unit), - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused Convert node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(q_node_unit); - - return Status::OK(); -} - -/** - * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. - * Should be called in a topologically ordered iteration of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. - */ -static Status TryHandleHardSigmoidSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { - // Looking for a standalone HardSigmoid to start the sequence. 
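Both fusion helpers above follow a validate-before-commit pattern: QNN validation of the fused node is attempted first, and a validation failure is deliberately swallowed (logged, then Status::OK() returned) so the EP falls back to the normal per-op builder path; only after validation passes is the model wrapper mutated. A self-contained sketch of that pattern with a toy Status type (not onnxruntime::common::Status):

#include <iostream>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
  bool IsOK() const { return ok; }
};

// If the expression fails, log and return success so the caller simply skips the fusion.
#define RETURN_OK_IF_ERROR(expr)              \
  do {                                        \
    Status _status = (expr);                  \
    if (!_status.IsOK()) {                    \
      std::cerr << _status.message << '\n';   \
      return Status::OK();                    \
    }                                         \
  } while (0)

// Illustrative validator that always rejects, to show the fallback path.
Status ValidateFusedNode() { return {false, "fused node rejected by validation"}; }

Status TryFuseSketch() {
  RETURN_OK_IF_ERROR(ValidateFusedNode());  // benign failure: caller proceeds without the fusion
  // ... only reached when validation succeeded: commit the fusion here ...
  return Status::OK();
}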
- if (start_node_unit.OpType() != "HardSigmoid" || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - NodeAttrHelper hs_attr_helper(start_node_unit); - float alpha = hs_attr_helper.Get("alpha", 0.2f); - float beta = hs_attr_helper.Get("beta", 0.5f); - constexpr float req_alpha = 1.0f / 6.0f; - constexpr float req_beta = 0.5f; - constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; - constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; - - // Check for explicit values of alpha and beta. - if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { - return Status::OK(); - } - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& hs_node = start_node_unit.GetNode(); - - // HardSigmoid must have a single Mul child. HardSigmoid must not produce a graph output. - auto children = graph_utils::FindChildrenByType(hs_node, "Mul"); - if (children.size() != 1 || hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { - return Status::OK(); - } - - const Node& mul_node = *children[0]; - const auto mul_node_unit_it = node_unit_map.find(&mul_node); - ORT_RETURN_IF(mul_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - const NodeUnit* mul_node_unit = mul_node_unit_it->second; - - // Check if Mul node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(mul_node_unit) != 0) { - return Status::OK(); - } - - // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - // Input to HardSigmoid must also be the other input to the Mul. - auto& hs_input_name = start_node_unit.Inputs()[0].node_arg.Name(); - const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || - mul_node.InputDefs()[1]->Name() == hs_input_name; - - if (!same_root_input) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = mul_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. 
HardSigmoid name: [" << start_node_unit.Name() - << "] Mul name: [" << mul_node_unit->Name() << "]"; - - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused HardSwish node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(mul_node_unit); - - return Status::OK(); -} - -using FusionFunc = Status (*)(std::vector&, - QnnModelWrapper&, - const NodeUnit&, - const std::unordered_map&, - const std::unordered_set&, - const logging::Logger&, - bool); - -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool validate) { - // Maps a starting operator type to the fusion function. - static std::unordered_map fusions = { - {"DequantizeLinear", TryHandleConvertSequence}, - {"HardSigmoid", TryHandleHardSigmoidSequence}, - }; - - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). - if (starting_node.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - auto iter = fusions.find(starting_node.OpType()); - if (iter != fusions.end()) { - fused_nodes.clear(); - - FusionFunc fusion_func = iter->second; - ORT_RETURN_IF_ERROR(fusion_func(fused_nodes, qnn_model_wrapper, starting_node, node_unit_map, - handled_node_units, logger, validate)); - } - - return Status::OK(); -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h deleted file mode 100644 index 39e2e71c01d8..000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" - -namespace onnxruntime { -namespace qnn { - -/** - * Tries to fuse a node sequence starting from the given starting node. Should be called in a topologically ordered - * walk of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. 
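As the removed HardSigmoid fusion above shows, x * HardSigmoid(x) is only equivalent to HardSwish(x) when HardSigmoid uses alpha = 1/6 and beta = 1/2 (HardSwish(x) = x * max(0, min(1, x/6 + 1/2))), and the code accepts them only within a machine-epsilon-scaled tolerance. A standalone version of that coefficient check:

#include <cmath>
#include <limits>

// Accept alpha/beta only if they match 1/6 and 1/2 to within a relative epsilon,
// mirroring the comparison in TryHandleHardSigmoidSequence.
bool MatchesHardSwishCoefficients(float alpha, float beta) {
  constexpr float req_alpha = 1.0f / 6.0f;
  constexpr float req_beta = 0.5f;
  constexpr float alpha_eps = std::numeric_limits<float>::epsilon() * req_alpha;
  constexpr float beta_eps = std::numeric_limits<float>::epsilon() * req_beta;
  return std::abs(alpha - req_alpha) <= alpha_eps && std::abs(beta - req_beta) <= beta_eps;
}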
- */ -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation); -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 503943dfb636..83f9184d3361 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -7,7 +7,7 @@ #include "QnnOpDef.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_fusions.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -117,49 +117,20 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } - std::unordered_set handled_node_units; + std::vector> qnn_node_groups; + qnn_node_groups.reserve(node_unit_holder.size()); - // Op builer - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer.GetNode(node_indices[i])); + ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map, + node_unit_holder.size(), logger_)); - // Check whether it's part of NodeUnit - const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); - // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) - const std::string& op_type = node_unit.OpType(); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_); - if (node != &node_unit.GetNode()) { - continue; + if (!status.IsOK()) { + LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " + << status.ErrorMessage() << std::endl; + return status; } - - if (handled_node_units.count(&node_unit) != 0) { - continue; // Already handled. - } - - // Try to see if this node unit can be fused. 
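The removed TryFusions body above dispatched on the starting node's op type through a static map from op type to fusion function ("DequantizeLinear" and "HardSigmoid" were the two entries). The same lookup-and-dispatch shape, reduced to standard C++ with a hypothetical Context in place of the real QNN builder arguments:

#include <functional>
#include <string>
#include <unordered_map>

struct Context {};  // stand-in for QnnModelWrapper, node maps, logger, etc.
using FusionFn = std::function<bool(Context&)>;

// Look up a fusion routine by the starting op type and run it if one is registered.
bool TryFusionsSketch(Context& ctx, const std::string& start_op_type,
                      const std::unordered_map<std::string, FusionFn>& fusions) {
  const auto it = fusions.find(start_op_type);
  return it != fusions.end() && it->second(ctx);
}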
- std::vector fused_nodes; - ORT_RETURN_IF_ERROR(TryFusions(fused_nodes, qnn_model_wrapper, node_unit, node_unit_map, - handled_node_units, logger_, false /*do_op_validation*/)); - - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - } - continue; - } - - LOGS(logger_, VERBOSE) << " node name: [" << node->Name() - << "] node optype: [" << op_type - << "] as part of the NodeUnit type: [" << node_unit.OpType() - << "] name: [" << node_unit.Name() - << "]"; - if (const auto* op_builder = GetOpBuilder(op_type)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); - } - - handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index c8537307ef3b..657224f68f71 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -239,6 +239,8 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, std::string error_msg; bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg); if (!rt) { + // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more + // specific validation error (instead of "failed to add node"). LOGS(logger_, WARNING) << error_msg; } return rt; @@ -617,6 +619,12 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + + // NOTE: Masking off top 4 bits to workaround a QNN INT4 accuracy bug. + // Docs explicitly state that masking off top 4 bits should not be required. + for (size_t i = 0; i < dst.size(); i++) { + dst[i] &= 0x0F; // -3 (0b1111_1101) becomes 13 (0b0000_1101) + } } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); const size_t num_elems = shape.Size(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h new file mode 100644 index 000000000000..f9ef01411310 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/logging/logging.h" +#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a group of NodeUnits that QNN EP translates into a core QNN operator. Can represent a single NodeUnit +/// or a fusion of multiple NodeUnits (e.g., DQ* -> Conv -> Relu -> Q). +/// +class IQnnNodeGroup { + public: + virtual ~IQnnNodeGroup() = default; + + // Returns an OK status if this IQnnNodeGroup is supported by QNN. + virtual Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Adds this IQnnNodeGroup to the QNN model wrapper. 
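The UnpackInitializerData change above masks off the top four bits of every unpacked INT4 element as a workaround for a QNN accuracy issue; the in-code comment gives the concrete effect (-3 becomes 13). A two-line demonstration of that bit behavior on an int8_t holding a sign-extended int4 value:

#include <cstdint>
#include <cstdio>

int main() {
  int8_t unpacked = -3;            // sign-extended int4: 0b1111'1101
  unpacked &= 0x0F;                // clear the top four bits: 0b0000'1101
  std::printf("%d\n", unpacked);   // prints 13
  return 0;
}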
+ virtual Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Returns a list of NodeUnits contained by this IQnnNodeGroup. + virtual gsl::span GetNodeUnits() const = 0; + + /// + /// Returns the "target" NodeUnit of the group. This is important for topological ordering of IQnnNodeGroups. + /// The target should be the first NodeUnit where all input paths (of the IQnnNodeGroup) converge. + /// For example, "Conv" should be the target NodeUnit for the following IQnnNodeGroup with 6 NodeUnits. + /// input0 -> DQ -> Conv -> Relu -> Q + /// ^ + /// | + /// input1 -> DQ ----+ + /// + /// Target NodeUnit in IQnnNodeGroup + virtual const NodeUnit* GetTargetNodeUnit() const = 0; + + // Returns a string representation of the IQnnNodeGroup's type. + virtual std::string_view Type() const = 0; +}; + +/// +/// Traverses the ONNX graph to create IQnnNodeGroup objects, each containing one or more NodeUnits. +/// The returned IQnnNodeGroup objects are sorted in topological order. +/// +/// Output vector into which the resulting IQnnNodeGroup objects are stored. +/// Contains reference to the ONNX GraphViewer and used for validaton on QNN +/// Maps a Node* to a NodeUnit* +/// The number of NodeUnits in the ONNX graph. +/// Logger +/// Status with potential error +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + size_t num_node_units, + const logging::Logger& logger); +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc new file mode 100644 index 000000000000..813bba8a5952 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -0,0 +1,480 @@ +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Gets the scale, zero-point, and zero-point type for a QuantizeLinear node that uses per-tensor quantization. +static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& scale, + /*out*/ int32_t& zero_point, + /*out*/ int32_t& zp_data_type) { + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); + const auto& q_inputs = q_node_unit.GetNode().InputDefs(); + + // Require an explicit zero-point input for now. + if (q_inputs.size() != 3 || !q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Exists()) { + return false; + } + + std::vector zero_points; + Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Name(), + zero_points, zp_data_type); + + // Should only have one zero-point (per-tensor). + if (!status.IsOK() || zero_points.size() != 1) { + return false; + } + zero_point = -zero_points[0]; // QNN zero-points are negated. + + std::vector scales; + status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ_SCALE_INPUT_IDX]->Name(), scales); + + // Should only have one scale (per-tensor). 
+ if (!status.IsOK() || scales.size() != 1) { + return false; + } + + scale = scales[0]; + return true; +} + +// Computes the floating point range (rmin, rmax) from a QuantizeLinear node's scale/zero-point. +static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& rmin, + /*out*/ float& rmax) { + int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; + int32_t zero_point = 0; + float scale = 0.0f; + + if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { + return false; + } + + switch (zp_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + default: + return false; + } + + return true; +} + +// Returns true if the Clip in the sequence (Clip -> Q) can be removed because it is made redundant by the Q. +static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& clip_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QUANTIZE_LINEAR); + float rmin = 0.0f; + float rmax = 0.0f; + + if (!GetQRminRmax(qnn_model_wrapper, q_node_unit, rmin, rmax)) { + return false; + } + + float clip_min = std::numeric_limits::lowest(); + float clip_max = std::numeric_limits::max(); + + if (!onnxruntime::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), + clip_min, clip_max, logger)) { + return false; + } + + // The clip range must entirely overlap the quantization range (quantization can be smaller). + // Clip range: [------------------] + // Quant range: [-------------] + constexpr float epsilon = std::numeric_limits::epsilon(); + if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) { + return false; + } + + return true; +} + +// Returns true if the Relu in the sequence (Relu -> Q) can be removed because it is made redundant by the Q. +static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) { + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); + int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; + int32_t zero_point = 0; + float scale = 0.0f; + + if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { + return false; + } + + // Relu is redundant if the zero-point is set to the smallest quantized value. 
+ switch (zp_data_type) { + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16: + return zero_point == static_cast(std::numeric_limits::lowest()); + default: + return false; + } +} + +// Returns true if the Clip/Relu in the sequence (Clip/Relu -> Q) can be removed because it is made redundant by the Q. +static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + const std::string& activation_type = activation_node_unit.OpType(); + + if (activation_type == "Relu") { + return CanQRelaceRelu(qnn_model_wrapper, q_node_unit); + } + + if (activation_type == "Clip") { + return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit, logger); + } + + return false; +} + +// Returns the parent DQ nodes for a given node. +static std::vector FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) { + // Get all parent DQ nodes sorted by destination argument index. + std::vector parents(node.InputDefs().size(), nullptr); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) { + if (it->GetNode().OpType().compare(DEQUANTIZE_LINEAR) == 0) { + parents[it->GetDstArgIndex()] = &(it->GetNode()); + } + } + + // Remove all the nodes which are not in the graph_viewer + parents.erase(std::remove_if(parents.begin(), parents.end(), + [&graph_viewer](const Node* _node) { + return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; + }), + parents.end()); + + return parents; +} + +// Gets the parent DQ nodes for the given Conv node. This fuction checks that the DQs are not a part of +// any other NodeUnit and that every Conv input comes from a parent DQ. +static bool GetConvDQs( + const GraphViewer& graph_viewer, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Node& conv_node, + /*out*/ std::array& dq_node_units) { + if (conv_node.OpType() != "Conv" && conv_node.OpType() != "ConvTranspose") { + return false; + } + + // Count number of inputs to Conv node. + const auto& conv_inputs = conv_node.InputDefs(); + const size_t num_conv_inputs = std::count_if(conv_inputs.cbegin(), conv_inputs.cend(), + [](const NodeArg* input) { return input && input->Exists(); }); + + // Get the Conv's parent DQ nodes. + std::vector dq_nodes = FindParentDQNodes(graph_viewer, conv_node); + const size_t num_dqs = dq_nodes.size(); + + // Within a QDQ node group, a target node input is the only consumer of each DQ. + if ((num_conv_inputs != num_dqs) || (num_dqs > dq_node_units.size())) { + return false; + } + + dq_node_units.fill(nullptr); + for (size_t i = 0; i < num_dqs; i++) { + const Node* dq_node = dq_nodes[i]; + + // DQ must not produce a graph output. + if (!dq_node || graph_viewer.NodeProducesGraphOutput(*dq_node)) { + return false; + } + + // Conv should be the only consumer of a parent DQ. 
+ const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); + if (!dq_has_single_output_edge_to_target) { + return false; + } + + // DQ node must be part of a "standalone" NodeUnit. + const auto it = node_to_node_unit.find(dq_node); + if (it == node_to_node_unit.end()) { + return false; + } + const NodeUnit* dq_node_unit = it->second; + if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) { + return false; + } + if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return false; + } + + dq_node_units[i] = dq_node_unit; + } + + return true; +} + +// Checks that the input and output data types are valid for a QDQ Conv. +static bool CheckQDQConvDataTypes(std::array& dq_node_units, + gsl::not_null q_node_unit) { + assert(q_node_unit->OpType() == QUANTIZE_LINEAR); + // input and output types need to be same + int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input != dt_output) { + return false; + } + + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + if (dt_weight != dt_input) { + return false; + } + } + + if (dq_node_units[2] != nullptr) { // has bias + int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return false; + } + } + + return true; +} + +// Utility function to either validate or create a quantized QNN Conv node. The function creates a temporary +// custom NodeUnit that excludes the Clip/Relu because it is redundant. This custom NodeUnit is passed to our +// existing Conv OpBuilder for creation or validation via QNN APIs. 
+#define ValidateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate) { + const size_t num_dqs = dq_node_units.size(); + constexpr size_t max_num_dqs = 3; + ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); + ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == QUANTIZE_LINEAR, + "Expected Conv/ConvTranspose and QuantizeLinear but got ", conv_node_unit->OpType(), " and ", + q_node_unit->OpType()); + + std::array dq_nodes_buf = {}; + for (size_t i = 0; i < num_dqs; i++) { + dq_nodes_buf[i] = &dq_node_units[i]->GetNode(); + } + gsl::span dq_nodes(dq_nodes_buf.data(), num_dqs); + + std::array q_nodes = {&q_node_unit->GetNode()}; + const Node& target_node = conv_node_unit->GetNode(); + + // Populate NodeUnit inputs + std::vector inputs; + inputs.reserve(num_dqs); + for (const Node* dq_node : dq_nodes) { + const auto dq_inputs = dq_node->InputDefs(); + const auto& dq_attrs = dq_node->GetAttributes(); + + std::optional axis; + if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; + inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param}); + } + + // Populate NodeUnit outputs and output edges + std::vector outputs; + Node::EdgeSet output_edges; + for (const Node* q_node : q_nodes) { + const auto q_inputs = q_node->InputDefs(); + const auto& q_attrs = q_node->GetAttributes(); + const auto q_outputs = q_node->OutputDefs(); + + std::optional axis; + if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; + outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); + + // Gather output edges out of the Q node. + auto q_cur_edge = q_node->OutputEdgesBegin(); + auto q_end_edge = q_node->OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); + } + } + + NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, + inputs, outputs, num_dqs, output_edges); + const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); + if (conv_op_builder == nullptr) { + return Status::OK(); + } + + if (validate) { + return conv_op_builder->IsOpSupported(qnn_model_wrapper, custom_node_unit, logger); + } + + return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); +} + +// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. +// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. 
+std::unique_ptr ConvActivationFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Expect that this function is called with a standalone Conv or ConvTranspose. + const auto& conv_type = conv_node_unit.OpType(); + + if ((conv_type != "Conv" && conv_type != "ConvTranspose") || + (conv_node_unit.UnitType() != NodeUnit::Type::SingleNode)) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Conv must have a single Relu or Clip child. + const std::array activation_op_types = {"Relu", "Clip"}; + const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + if (activation_node_unit == nullptr) { + return nullptr; + } + + // Relu/Clip must have a single Q child. + const std::array q_op_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. + if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit, logger)) { + return nullptr; + } + + // Create a QDQ node group with DQ* -> Conv -> Q + const Node& conv_node = conv_node_unit.GetNode(); + std::array dq_node_units = {}; + if (!GetConvDQs(graph_viewer, + node_to_node_unit, + node_unit_to_qnn_node_group, + conv_node, dq_node_units)) { + return nullptr; + } + + if (!CheckQDQConvDataTypes(dq_node_units, q_node_unit)) { + return nullptr; + } + + return std::make_unique(*dq_node_units[0], + *dq_node_units[1], + dq_node_units[2], + conv_node_unit, + *activation_node_unit, + *q_node_unit); +} + +ConvActivationFusion::ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit) + : node_units_{} { + size_t i = 0; + node_units_[i++] = &dq_node_unit_0; + node_units_[i++] = &dq_node_unit_1; + if (dq_node_unit_2 != nullptr) { + node_units_[i++] = dq_node_unit_2; + } + node_units_[i++] = &conv_node_unit; + node_units_[i++] = &activation_node_unit; + node_units_[i++] = &q_node_unit; + assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); +} + +Status ConvActivationFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); + + return ValidateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); +} + +Status ConvActivationFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); + + return CreateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); +} + +gsl::span ConvActivationFusion::GetNodeUnits() const { + const size_t num_node_units = node_units_.back() != nullptr ? 
6 : 5; + return gsl::make_span(node_units_.data(), num_node_units); +} + +const NodeUnit* ConvActivationFusion::GetTargetNodeUnit() const { + const size_t conv_index = node_units_.back() != nullptr ? 3 : 2; + return node_units_[conv_index]; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h new file mode 100644 index 000000000000..b604b25e943e --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ* -> Conv -> Relu/Clip -> Q sequence where the Relu (or Clip) is redundant +/// due to the quantization effects of the Q. This sequence is translated to a quantized QNN Conv. +/// All contained NodeUnits are of type SingleNode since they are not a part of an existing QDQ node unit. +/// +class ConvActivationFusion : public IQnnNodeGroup { + public: + ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ConvActivationFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ConvActivationFusion"; } + + /// + /// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. + /// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. + /// + /// Used for validation and to traverse/query the graph + /// Conv node unit (type SingleNode) that be part of the sequence. + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; // Last elem is nullptr if the optional bias DQ is missing. 
+}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc new file mode 100644 index 000000000000..ce87ac4a3d21 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -0,0 +1,179 @@ +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, bool validate); +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node); + +std::unique_ptr DQQFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + // Expect that this function is called with a standalone DQ. + if (dq_node_unit.OpType() != DEQUANTIZE_LINEAR || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); + + // DQ must have a single Q child (1 output edge) and must not produce a graph output. + const std::array child_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // DQ and Q must have equal scale type and different zp type. 
+ if (!IsDQQConversion(graph_viewer, dq_node, q_node_unit->GetNode())) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, dq_node_unit, *q_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(dq_node_unit, *q_node_unit); +} + +DQQFusion::DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) + : node_units_{&dq_node_unit, &q_node_unit} { +} + +Status DQQFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status DQQFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span DQQFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* DQQFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + bool validate) { + assert(dq_node_unit.OpType() == DEQUANTIZE_LINEAR && q_node_unit.OpType() == QUANTIZE_LINEAR); + const auto& node_name = utils::GetNodeName(dq_node_unit); + const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused Convert node."); + } + + return Status::OK(); +} + +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + auto is_scalar_shape = [](const NodeArg& input_arg) -> bool { + auto shape = input_arg.Shape(); + if (shape == nullptr) { + return false; + } + + auto dim_size = shape->dim_size(); + return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1); + }; + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != QDQ_MAX_NUM_INPUTS || + q_input_defs.size() != QDQ_MAX_NUM_INPUTS || + !is_scalar_shape(*q_input_defs[QDQ_SCALE_INPUT_IDX]) || + !is_scalar_shape(*q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]) || + !is_scalar_shape(*dq_input_defs[QDQ_SCALE_INPUT_IDX]) || + !is_scalar_shape(*dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + 
graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_SCALE_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[QDQ_SCALE_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // All TensorProtos must have a data type + if (!q_zp_tensor_proto->has_data_type() || !dq_zp_tensor_proto->has_data_type() || + !q_scale_tensor_proto->has_data_type() || !dq_scale_tensor_proto->has_data_type()) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) && + (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h new file mode 100644 index 000000000000..90fe44c3af05 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ -> Q sequence that converts from one quantization type (e.g., uint8_t) to +/// another (e.g., uint16_t). This is translated into a QNN Convert operator, which is much faster than individual +/// ops. The DQ and Q are standalone NodeUnits that are not part of a QDQ node unit. +/// +class DQQFusion : public IQnnNodeGroup { + public: + DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(DQQFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "DQQFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid DQ -> Q sequence. + /// If so, returns a IQnnNodeGroup that contains the DQ and Q NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// DQ node unit that could start the DQ -> Q sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. 
+ /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc new file mode 100644 index 000000000000..76b172664648 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -0,0 +1,144 @@ +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +namespace onnxruntime { +namespace qnn { + +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, bool validate); + +std::unique_ptr HardSigmoidMulFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + + // Looking for a standalone HardSigmoid to start the sequence. + if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || + hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); + float alpha = hs_attr_helper.Get("alpha", 0.2f); + float beta = hs_attr_helper.Get("beta", 0.5f); + constexpr float req_alpha = 1.0f / 6.0f; + constexpr float req_beta = 0.5f; + constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; + constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; + + // Check for explicit values of alpha and beta. + if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { + return nullptr; + } + + // HardSigmoid must have a single Mul child (1 output edge) and must not produce a graph output. + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::array child_types = {"Mul"}; + const NodeUnit* mul_node_unit = GetOnlyChildOfType(graph_viewer, hardsigmoid_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (mul_node_unit == nullptr) { + return nullptr; + } + + // Input to HardSigmoid must also be the other input to the Mul. 
+ const Node& mul_node = mul_node_unit->GetNode(); + auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); + const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || + mul_node.InputDefs()[1]->Name() == hs_input_name; + + if (!same_root_input) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(hardsigmoid_node_unit, *mul_node_unit); +} + +HardSigmoidMulFusion::HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) + : node_units_{&hardsigmoid_node_unit, &mul_node_unit} { +} + +Status HardSigmoidMulFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status HardSigmoidMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span HardSigmoidMulFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* HardSigmoidMulFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, + bool validate) { + assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); + const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused HardSwish node."); + } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h new file mode 100644 index 000000000000..3b67f13492a4 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a HardSigmoid -> Mul sequence that computes `x * HardSigmoid(x)`. 
+/// This is translated into a QNN HardSwish operator. +/// The contained NodeUnits are of type SingleNode since they are not a part of a QDQ node unit. +/// +class HardSigmoidMulFusion : public IQnnNodeGroup { + public: + HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(HardSigmoidMulFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "HardSigmoidMulFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid HardSigmoid -> Mul sequence. + /// If so, returns a IQnnNodeGroup that contains the HardSigmoid and Mul NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// HardSigmoid node unit that could start the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc new file mode 100644 index 000000000000..9fb9e815321c --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group.h" + +#include +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +namespace onnxruntime { +namespace qnn { + +/// +/// A IQnnNodeGroup class that wraps a single NodeUnit. Most NodeUnits in the ONNX graph will +/// be wrapped by this class. 
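+/// It simply forwards IsSupported() / AddToModelBuilder() to the per-operator builder returned by
+/// qnn::GetOpBuilder(op_type).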
+/// +class QnnNodeUnitWrapper : public IQnnNodeGroup { + public: + explicit QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_->OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type, + "` are not supported by QNN EP.", op_type, " node `", + node_unit_->Name(), "` will not be assigned to QNN EP."); + + return op_builder->IsOpSupported(qmw, *node_unit_, logger); + } + + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_->OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type); + return op_builder->AddToModelBuilder(qmw, *node_unit_, logger, /*do_op_validation*/ false); + } + + gsl::span GetNodeUnits() const override { + return gsl::span{&node_unit_, 1ULL}; + } + + const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; } + std::string_view Type() const override { return "NodeUnit"; } + + private: + const NodeUnit* node_unit_; +}; + +/// +/// The type of a function that tries to fuse NodeUnits into a IQnnNodeGroup. +/// +using FusionFunc = std::unique_ptr (*)( + QnnModelWrapper&, + const NodeUnit&, + const std::unordered_map&, + const std::unordered_map&, + const logging::Logger&); + +/// +/// Given a starting NodeUnit, this function tries all possible fusions that start with that NodeUnit. +/// If successful, returns a IQnnNodeGroup object that represents the fusion of various NodeUnits. +/// Currently only handles standalone NodeUnits that are not in a QDQ unit but that can change in the future. +/// +/// QnnModelWrapper that contains the ONNX GraphViewer. Used for validation. +/// NodeUnit that potentially starts a fusion. +/// Maps a Node* to a NodeUnit* +/// Maps a NodeUnit* to a IQnnNodeGroup* +/// +/// IQnnNodeGroup representing the fusion or an empty std::unique_ptr +static std::unique_ptr TryQnnFusions( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& starting_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Maps a starting operator type to the fusion function. + static std::unordered_map fusions = { + {"DequantizeLinear", DQQFusion::TryFusion}, + {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, + {"Conv", ConvActivationFusion::TryFusion}, + {"ConvTranspose", ConvActivationFusion::TryFusion}, + }; + + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + auto iter = fusions.find(starting_node_unit.OpType()); + if (iter != fusions.end()) { + FusionFunc fusion_func = iter->second; + return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, + node_unit_to_qnn_node_group, logger); + } + return nullptr; +} + +// Traverses the ONNX Graph and groups NodeUnits into IQnnNodeGroup objects. Some IQnnNodeGroup objects +// represent a fusion of various NodeUnits. This function generates a vector of indices that +// represent the topological order of the qnn_node_groups. 
+static Status GetQnnNodeGroupsImpl(/*out*/ std::vector>& qnn_node_groups, + /*out*/ std::vector& sorted_qnn_node_group_indices, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + sorted_qnn_node_group_indices.reserve(num_node_units); + qnn_node_groups.reserve(num_node_units); + + std::unordered_map node_unit_to_qnn_node_group; + std::unordered_map fused_qnn_node_group_indices; + std::vector> sorted_node_units; + sorted_node_units.reserve(num_node_units); + + // Process just the fusions of NodeUnits first to ensure a correct topological order of all IQnnNodeGroups. + // This is the same approach taken by ORT utilities for grouping Nodes into NodeUnits. + for (NodeIndex node_index : sorted_node_indices) { + gsl::not_null node = graph_viewer.GetNode(node_index); + + // Get the NodeUnit associated with the node. + const auto node_unit_it = node_to_node_unit.find(node); + ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name()); + gsl::not_null node_unit = node_unit_it->second; + + // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order. + if (node != &node_unit->GetNode()) { + continue; + } + + sorted_node_units.push_back(node_unit); + + if (node_unit_to_qnn_node_group.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + std::unique_ptr fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, + logger); + + if (fused_node_group) { + const size_t index = qnn_node_groups.size(); + fused_qnn_node_group_indices[fused_node_group.get()] = index; + + for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { + assert(fused_node_unit != nullptr); + node_unit_to_qnn_node_group.insert({fused_node_unit, fused_node_group.get()}); + } + + qnn_node_groups.push_back(std::move(fused_node_group)); + } + } + + // Create IQnnNodeGroups for the leftover NodeUnits that were not fused. + for (gsl::not_null node_unit : sorted_node_units) { + const auto it = node_unit_to_qnn_node_group.find(node_unit); + + if (it != node_unit_to_qnn_node_group.end()) { + // Already added this NodeUnit to a IQnnNodeGroup, so we'll skip it. + // However, if this NodeUnit is the "target" for the IQnnNodeGroup, then add its index to + // the sorted list of indices. 
+ gsl::not_null fused_qnn_node_group = it->second; + if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) { + sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]); + } + continue; + } + + const size_t index = qnn_node_groups.size(); + auto qnn_node_group = std::make_unique(*node_unit); + + node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()}); + qnn_node_groups.push_back(std::move(qnn_node_group)); + sorted_qnn_node_group_indices.push_back(index); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} + +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + std::vector sorted_qnn_node_group_indices; + std::vector> qnn_node_groups_holder; + ORT_RETURN_IF_ERROR(GetQnnNodeGroupsImpl(qnn_node_groups_holder, sorted_qnn_node_group_indices, qnn_model_wrapper, + node_to_node_unit, num_node_units, logger)); + + // Move IQnnNodeGroups to the output std::vector in sorted (topological) order. + qnn_node_groups.resize(0); + qnn_node_groups.reserve(qnn_node_groups_holder.size()); + for (auto index : sorted_qnn_node_group_indices) { + assert(index < qnn_node_groups_holder.size()); + std::unique_ptr qnn_node_group = std::move(qnn_node_groups_holder[index]); + qnn_node_groups.push_back(std::move(qnn_node_group)); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc new file mode 100644 index 000000000000..5548d7d37c37 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -0,0 +1,66 @@ +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map) { + const Node& parent_node = parent_node_unit.GetNode(); + + // Parent must have a single child (1 output edge) and must not produce a graph output. + if (parent_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(parent_node)) { + return nullptr; + } + + // Child must be of a valid type. + const Node& child_node = parent_node.OutputEdgesBegin()->GetNode(); + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + return nullptr; // Node is not in this GraphViewer + } + const std::string& child_type = child_node.OpType(); + bool is_valid_child_type = false; + + for (const auto& valid_op_type : child_op_types) { + if (valid_op_type == child_type) { + is_valid_child_type = true; + break; + } + } + + if (!is_valid_child_type) { + return nullptr; + } + + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. 
Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (qnn_node_group_map.count(child_node_unit) != 0) { + return nullptr; + } + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + return child_node_unit; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h new file mode 100644 index 000000000000..0d11d21906cc --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { +constexpr const char* QUANTIZE_LINEAR = "QuantizeLinear"; +constexpr const char* DEQUANTIZE_LINEAR = "DequantizeLinear"; +constexpr size_t QDQ_MAX_NUM_INPUTS = 3; +constexpr size_t QDQ_SCALE_INPUT_IDX = 1; +constexpr size_t QDQ_ZERO_POINT_INPUT_IDX = 2; + +/// +/// Utility function to get a child NodeUnit. The returned NodeUnit must be the parent's only child, must be +/// of the expected type, and must not be a part of another IQnnNodeGroup. +/// +/// GraphViewer containing all Nodes +/// Parent NodeUnit +/// Valid child types +/// Maps a Node to its NodeUnit +/// Maps a NodeUnit to its IQnnNodeGroup. +/// Used to check that the child has not already been added to another IQnnNodeGroup. +/// +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& node_unit_to_qnn_node_group); + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index c2e500b8980a..d6c93a8f226e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -509,6 +509,9 @@ Status GetQminQmax(const Qnn_DataType_t qnn_data_type, } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { qmin = static_cast(std::numeric_limits::min()); qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); } else { ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); } @@ -519,15 +522,27 @@ Status GetQuantParams(float rmin, float rmax, const Qnn_DataType_t qnn_data_type, float& scale, - int& zero_point) { + int32_t& zero_point, + bool symmetric) { std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); + if (symmetric) { + float abs_max = std::max(abs(rmax), abs(rmin)); + rmax = abs_max; + rmin = -abs_max; + } + float qmin = 0.0f; float qmax = 255.0f; ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); scale = (rmax - rmin) / (qmax - qmin); - const float initial_zero_point = qmin - (rmin / scale); - zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); + float initial_zero_point = 0.0f; + if (symmetric) { + initial_zero_point = std::round(rmin + rmax) / 2; + } else { + 
initial_zero_point = qmin - (rmin / scale); + } + zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); // To match QNN quantization definition zero_point = 0 - zero_point; return Status::OK(); @@ -541,7 +556,7 @@ double Dequantize(int32_t offset, float scale, const double quant_value) { Status Quantize(const double double_value, const float scale, - const int zero_point, + const int32_t zero_point, const Qnn_DataType_t qnn_data_type, int& quant_value) { int qmin = 0; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 2392040d284b..aa4a27460563 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -93,13 +93,14 @@ Status GetQuantParams(float rmin, float rmax, const Qnn_DataType_t qnn_data_type, float& scale, - int& zero_point); + int32_t& zero_point, + bool symmetric = false); double Dequantize(int32_t offset, float scale, const double quant_value); Status Quantize(const double double_value, const float scale, - const int zero_point, + const int32_t zero_point, const Qnn_DataType_t qnn_data_type, int& quant_value); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 0ddaa9769421..fc64d63ede33 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -16,10 +16,10 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/builder/qnn_fusions.h" #include "core/providers/partitioning_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/framework/run_options.h" @@ -199,6 +199,13 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; + + // For the case that workaround QNN context PD memory limit, user need split the model into pieces and + // generate the QNN context model separately. + // It could happen that the generated EPContext node in separate graph has same node name. + // User can set this context_node_name_prefix for each split pieces to avoid that happens. + context_node_name_prefix_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextNodeNamePrefix, ""); + LOGS_DEFAULT(VERBOSE) << "User specified QNN context node name prefix: " << context_node_name_prefix_; } static const std::string BACKEND_PATH = "backend_path"; @@ -405,25 +412,35 @@ QNNExecutionProvider::~QNNExecutionProvider() { #endif } -bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const { - const std::string& op_type = node_unit.OpType(); - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." 
- << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); +// Logs information about the supported/unsupported nodes. +static void LogNodeSupport(const logging::Logger& logger, + logging::Severity log_severity, + logging::DataType log_data_type, + const onnxruntime::CodeLocation& call_site, + const qnn::IQnnNodeGroup& qnn_node_group, + Status support_status) { + if (!logger.OutputIsEnabled(log_severity, log_data_type)) { + return; + } + + std::ostringstream oss; + oss << (support_status.IsOK() ? "Validation PASSED " : "Validation FAILED ") << "for nodes (" + << qnn_node_group.Type() << "):" << std::endl; + for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + oss << "\tOperator type: " << node->OpType() + << " Node name: " << node->Name() + << " Node index: " << node->Index() << std::endl; } - supported = (Status::OK() == status); } - return supported; + if (!support_status.IsOK()) { + oss << "\tREASON : " << support_status.ErrorMessage() << std::endl; + } + + logging::Capture(logger, log_severity, logging::Category::onnxruntime, + log_data_type, call_site) + .Stream() + << oss.str(); } std::unordered_set @@ -462,68 +479,33 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, initializer_input_lookup, qnn_backend_manager_->GetQnnBackendType()); - std::unordered_set handled_node_units; - handled_node_units.reserve(node_unit_size); - - auto add_supported_nodes = [](std::unordered_set& supported_nodes, const NodeUnit* node_unit) { - for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { - supported_nodes.insert(node_in_group); - } - }; - - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - gsl::not_null node(graph_viewer.GetNode(node_indices[i])); - - // Get the node_unit associated with the node. Note that the node may not be the node_unit's target node. - const NodeUnit* node_unit = node_unit_map.at(node); - - // Visiting 'nodes' in topological order does not guarantee that 'node_units' are - // also visited in topological order. Skip this node if it is not the node_unit's target node - // to ensure 'node_units' are visited in topological order. - if (node != &node_unit->GetNode()) { - continue; - } - - if (handled_node_units.count(node_unit) != 0) { - continue; // Already handled this node unit - } + std::vector> qnn_node_groups; + qnn_node_groups.reserve(node_unit_size); - // Try to see if this node unit can be fused. 
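For reference, the LogNodeSupport helper above assembles one multi-line message per IQnnNodeGroup; with verbose logging enabled, a failed validation would look roughly like the following (the group type, node names, and indices here are illustrative, not taken from a real run):

    Validation FAILED for nodes (ConvActivationFusion):
    	Operator type: Conv Node name: conv_0 Node index: 4
    	Operator type: Relu Node name: relu_0 Node index: 5
    	REASON : <error message returned by IsSupported>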
- std::vector fused_nodes; - Status fusion_status = TryFusions(fused_nodes, qnn_model_wrapper, *node_unit, node_unit_map, - handled_node_units, logger, true /*do_op_validation*/); + if (Status status = qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, + node_unit_map, node_unit_size, logger); + !status.IsOK()) { + LOGS(logger, ERROR) << status.ErrorMessage(); + return {}; + } - if (!fusion_status.IsOK()) { - LOGS(logger, WARNING) << "Failed to apply fusion: " << fusion_status.ErrorMessage(); - handled_node_units.insert(node_unit); - continue; - } + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->IsSupported(qnn_model_wrapper, logger); + const bool supported = status.IsOK(); - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - add_supported_nodes(supported_nodes, fused_node_unit); - } - continue; + constexpr auto log_severity = logging::Severity::kVERBOSE; + constexpr auto log_data_type = logging::DataType::SYSTEM; + if (logger.OutputIsEnabled(log_severity, log_data_type)) { + LogNodeSupport(logger, log_severity, log_data_type, ORT_WHERE, *qnn_node_group, status); } - // Couldn't fuse the node unit. See if it is supported by itself. - const bool supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; - if (supported) { - add_supported_nodes(supported_nodes, node_unit); + for (const NodeUnit* node_unit : qnn_node_group->GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + supported_nodes.insert(node); + } + } } - - handled_node_units.insert(node_unit); } return supported_nodes; @@ -565,7 +547,8 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, supported_groups.begin(), supported_groups.end(), std::back_inserter(result), [&](const auto& supported_partition) { - return utils::MakeComputeCapability(graph_viewer, supported_partition, gen_metadef_name, QNN); + return utils::MakeComputeCapability(graph_viewer, supported_partition, gen_metadef_name, QNN, + /*drop_constant_initializers*/ false); // TODO: could this be set to true? }); const size_t num_of_partitions = result.size(); @@ -612,7 +595,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto gen_metadef_name = [&]() { uint64_t model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(QNN, "_", model_hash, "_", metadef_id); + return MakeString(QNN, context_node_name_prefix_, "_", model_hash, "_", metadef_id); }; // For model with EPContext, make sure each partition only has one single EPContext node @@ -660,7 +643,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // Create partitions from supported nodes. std::vector> partitions = utils::CreateSupportedPartitions( - graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true); + graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map); // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node. 
// We also count the number of supported nodes in all valid partitions. diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e7419dabb14d..4c48370492ef 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -53,9 +53,6 @@ class QNNExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; private: - bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const; - std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, @@ -80,6 +77,7 @@ class QNNExecutionProvider : public IExecutionProvider { std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; + std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 382b3ac93252..a9394838aa78 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -388,6 +388,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) = 0; // GraphProto + virtual std::unique_ptr GraphProto__construct() = 0; virtual void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) = 0; virtual void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index de6c1da1d643..242c7126f327 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -146,6 +146,7 @@ struct AttributeProto final { }; struct GraphProto final { + static std::unique_ptr Create() { return g_host->GraphProto__construct(); } static void operator delete(void* p) { g_host->GraphProto__operator_delete(reinterpret_cast(p)); } void operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); } diff --git a/onnxruntime/core/providers/vitisai/imp/capability.cc b/onnxruntime/core/providers/vitisai/imp/capability.cc index 58522a45a151..6d188076fe61 100644 --- a/onnxruntime/core/providers/vitisai/imp/capability.cc +++ b/onnxruntime/core/providers/vitisai/imp/capability.cc @@ -51,7 +51,11 @@ GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph, std::vector node_indexs = graph.GetNodesInTopologicalOrder(); node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end()); - node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end()); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), + [&](NodeIndex index) { + auto node = graph.GetNode(index); + 
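+                                     // The lookup key is now domain-qualified: a custom op in domain "com.xilinx"
+                                     // with type "super_layer" is matched as "com.xilinx:super_layer" rather than
+                                     // bare "super_layer" (example names are illustrative).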
return all_support_optypes_by_eps.count(node->Domain() + ":" + node->OpType()) == 0; }), + node_indexs.end()); std::vector> result; for (auto& n : node_indexs) { diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 8c1dce0d3dc1..df47fa5cee4a 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -55,10 +55,15 @@ struct OrtVitisAIEpAPI { uint32_t (*vaip_get_version)(); void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); + void (*create_ep_context_nodes)( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps, + vaip_core::DllSafe>* ret_value) = nullptr; void Ensure() { if (handle_) return; auto& env = Provider_GetHost()->Env__Default(); + auto& logger = *Provider_GetHost()->LoggingManager_GetDefaultLogger(); #ifdef _WIN32 // this dll is already linked to the executable, normally a test program handle_ = reinterpret_cast(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll"))); @@ -81,6 +86,10 @@ struct OrtVitisAIEpAPI { (void**)&vaip_get_version); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache)); + status1 = (env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes)); + if (!status1.IsOK()) { + LOGS(logger, WARNING) << "create_ep_context_nodes is not defined, please upgrade onnxruntime_vitisai_ep.dll. 
However, it still works."; + } } private: @@ -146,6 +155,24 @@ void restore_backend_compilation_cache(const std::string& cache_dir, const std:: s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path); } +bool has_create_ep_context_nodes() { + return s_library_vitisaiep.create_ep_context_nodes != nullptr; +} + +std::optional> create_ep_context_nodes( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps) { + if (s_library_vitisaiep.create_ep_context_nodes) { + vaip_core::DllSafe> nodes; + s_library_vitisaiep.create_ep_context_nodes(ep_context_graph, eps, &nodes); + if (nodes.get()) { + auto ret = std::vector(*nodes); + return ret; + } + } + return std::nullopt; +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = @@ -173,7 +200,7 @@ void create_kernel_registry(std::vector domains) { auto def_builder = KernelDefBuilder::Create(); def_builder->SetName(op->GetName(op)); def_builder->SetDomain(domain->domain_.c_str()); - def_builder->SinceVersion(1); + def_builder->SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); if (op->version > 12) { auto input_count = op->GetInputTypeCount(op); for (auto i = 0u; i < input_count; i++) { @@ -183,7 +210,7 @@ void create_kernel_registry(std::vector domains) { def_builder->Provider(onnxruntime::kVitisAIExecutionProvider); KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { - // out = std::make_unique(info, *op); + out = std::make_unique(info, *op); return Status::OK(); }; std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn)); @@ -405,6 +432,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { graph.AddInitializedTensor(tensor); }; + the_global_api.get_model_path = [](const Graph& graph) -> const std::filesystem::path& { + return graph.ModelPath(); + }; + + the_global_api.create_empty_model = [](const std::filesystem::path& path, const std::vector>& opset) -> Model* { + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); + auto graph_proto = ONNX_NAMESPACE::GraphProto::Create(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + for (const auto& op : opset) { + auto* opset_import = model_proto->add_opset_import(); + *(opset_import->mutable_domain()) = op.first; + opset_import->set_version(op.second); + } + std::ignore = model_proto->mutable_graph(); // create a graph + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model = Model::Create(std::move(*model_proto), path, nullptr, logger); + return model.release(); + }; + + the_global_api.graph_set_inputs = [](Graph& graph, gsl::span inputs) { + graph.SetInputs(inputs); + }; + if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h index d34f7095b704..5d020e00ff5b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h @@ -26,6 +26,17 @@ class ExecutionProvider { virtual DllSafe> get_meta_def_constant_initializer() const = 0; virtual std::unique_ptr compile() const = 0; + + public: + inline void set_fused_node(const onnxruntime::Node* fused_node) { + fused_node_ = fused_node; + } + inline const onnxruntime::Node* get_fused_node() const { 
+ return fused_node_; + } + + private: + const onnxruntime::Node* fused_node_ = nullptr; }; class CustomOp { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 3fdbc60bb0ee..ae2a513a98e3 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -9,10 +9,14 @@ #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" - +#include void initialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); +std::optional> create_ep_context_nodes( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps); +bool has_create_ep_context_nodes(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 334673989048..e6aacfe1f027 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -8,12 +8,13 @@ #include #include #include +#include struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (3u) -#define VAIP_ORT_API_MINOR (1u) +#define VAIP_ORT_API_MAJOR (4u) +#define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { uint32_t magic; // 'VAIP' or something else to make sure the following field @@ -222,7 +223,11 @@ struct OrtApiForVaip { const std::vector& data); // [88] TensorProto* (*tensor_proto_new_bf16)( const std::string& name, const std::vector& shape, - const std::vector& data); // [89] + const std::vector& data); // [89] + const std::filesystem::path& (*get_model_path)(const Graph& graph); // [90] + Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector>& opset); //[91] + void (*graph_set_inputs)(Graph& graph, + gsl::span inputs); // [92] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 036831df7a9c..756bda2199e8 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -44,7 +44,7 @@ VitisAIExecutionProvider::VitisAIExecutionProvider( void VitisAIExecutionProvider::CreateKernelRegistry() { for (const auto& domain : get_domains_vitisaiep()) { for (const auto* op : domain->custom_ops_) { - vitisai_optypes_.insert(op->GetName(op)); + vitisai_optypes_.insert(domain->domain_ + ":" + op->GetName(op)); } } } @@ -58,8 +58,15 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c // All preconditions are supposed to have happened. 
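   // If the VitisAI library exports create_ep_context_nodes (newer onnxruntime_vitisai_ep), let it decide
   // which nodes of the EP context graph to surface; otherwise fall back to returning every node of the
   // prebuilt graph, which preserves the previous behavior.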
if (p_ep_ctx_model_) { auto& graph = p_ep_ctx_model_->MainGraph(); - for (const auto* p_node : graph.Nodes()) { - ep_context_node_ptrs.push_back(p_node); + if (has_create_ep_context_nodes()) { + auto nodes = create_ep_context_nodes(graph, **execution_providers_); + if (nodes.has_value()) { + ep_context_node_ptrs.assign(nodes->begin(), nodes->end()); + } + } else { + for (const auto* p_node : graph.Nodes()) { + ep_context_node_ptrs.push_back(p_node); + } } } return ep_context_node_ptrs; @@ -100,7 +107,7 @@ void VitisAIExecutionProvider::FulfillEPContextEnablement( auto& ep_ctx_graph = p_ep_ctx_model_->MainGraph(); if (!ep_ctx_embed_mode_) { auto ep_ctx_cache_path_str = GetEPContextCacheFileLocation(ep_ctx_model_file_loc_, model_path_str_); - std::ofstream ep_ctx_cache_ofs(ep_ctx_cache_path_str.c_str(), std::ios::trunc); + std::ofstream ep_ctx_cache_ofs(ep_ctx_cache_path_str.c_str(), std::ios::trunc | std::ios::binary); if (!ep_ctx_cache_ofs.is_open()) { ORT_THROW("Failed to open a file to write EP context cache: ", ep_ctx_cache_path_str.c_str()); } @@ -136,7 +143,7 @@ std::vector> VitisAIExecutionProvider::GetCap info_["cacheDir"] = cache_dir; info_["cacheKey"] = cache_key; LOGS_DEFAULT(VERBOSE) << "Trying getting compilation cache from " << PathToUTF8String(ep_ctx_model_file_loc_); - auto ep_ctx_payload = RetrieveEPContextCache(graph_viewer.GetGraph(), ep_ctx_model_file_loc_, false); + auto ep_ctx_payload = RetrieveEPContextCache(graph_viewer.GetGraph(), ep_ctx_model_file_loc_, true); restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string()); } else { if (fs::exists(ep_ctx_model_file_loc_) && fs::is_regular_file(ep_ctx_model_file_loc_) && ep_ctx_enabled_) { @@ -187,6 +194,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectorexecution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get()); compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; @@ -204,7 +212,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder_, + const emscripten::val& wnn_builder, const WebnnDeviceType device_type, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -103,7 +103,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const auto* node(graph_viewer.GetNode(node_idx)); bool supported = false; // Firstly check if platform supports the WebNN op. - if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) { + if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, logger); } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 496f886e5a07..fc13ce201f2e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. 
std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, - const emscripten::val& wnn_builder_, + const emscripten::val& wnn_builder, const WebnnDeviceType device_type, const logging::Logger& logger); static const InlinedHashMap op_map = { @@ -167,10 +167,11 @@ static const InlinedHashMap op_map = { {"Concat", {"concat", true}}, {"Conv", {"conv2d", true}}, {"ConvInteger", {"conv2dInteger", false}}, - {"ConvTranspose", {"convTranspose2d", false}}, + {"ConvTranspose", {"convTranspose2d", true}}, {"Cos", {"cos", true}}, {"Div", {"div", true}}, {"DequantizeLinear", {"dequantizeLinear", false}}, + {"Dropout", {"identity", true}}, {"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}}, {"Elu", {"elu", true}}, {"Equal", {"equal", true}}, @@ -241,14 +242,14 @@ static const InlinedHashMap op_map = { {"Where", {"where", true}}, }; -inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_, +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder, const WebnnDeviceType device_type) { // Returns false if the op_type is not listed in the op_map. if (op_map.find(op_type) == op_map.end()) { return false; } // Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser. - if (!wnn_builder_[op_map.find(op_type)->second.opName].as()) { + if (!wnn_builder[op_map.find(op_type)->second.opName].as()) { return false; } // The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index af0f0133b497..626aaf5c71b7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -36,6 +36,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Elu") { options.set("alpha", helper.Get("alpha", 1.0f)); output = model_builder.GetBuilder().call("elu", input, options); @@ -46,20 +47,20 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("beta", helper.Get("beta", 0.5f)); output = model_builder.GetBuilder().call("hardSigmoid", input, options); } else if (op_type == "HardSwish") { - output = model_builder.GetBuilder().call("hardSwish", input); + output = model_builder.GetBuilder().call("hardSwish", input, options); } else if (op_type == "LeakyRelu") { options.set("alpha", helper.Get("alpha", 0.0f)); output = model_builder.GetBuilder().call("leakyRelu", input, options); } else if (op_type == "Relu") { - output = model_builder.GetBuilder().call("relu", input); + output = model_builder.GetBuilder().call("relu", input, options); } else if (op_type == "Sigmoid") { - output = model_builder.GetBuilder().call("sigmoid", input); + output = model_builder.GetBuilder().call("sigmoid", input, options); } else if (op_type == "Softplus") { - output = model_builder.GetBuilder().call("softplus", input); + output = model_builder.GetBuilder().call("softplus", input, options); } else if (op_type == "Softsign") { - output = model_builder.GetBuilder().call("softsign", input); + output = model_builder.GetBuilder().call("softsign", input, options); } else if (op_type == "Tanh") { - output = model_builder.GetBuilder().call("tanh", input); + output = 
model_builder.GetBuilder().call("tanh", input, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 1330a3e35487..05f3a742a377 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -40,28 +40,21 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); int64_t axis = helper.Get("axis", 0); const auto keep_dims = helper.Get("keepdims", 1); - const auto select_last_index = helper.Get("select_last_index", 0); axis = HandleNegativeAxis(axis, input_rank); - emscripten::val axes = emscripten::val::array(); - axes.call("push", static_cast(axis)); emscripten::val options = emscripten::val::object(); - options.set("axes", axes); options.set("keepDimensions", keep_dims == 1); - options.set("selectLastIndex", select_last_index == 1); - // TODO: use WebNN's opSupportLimits API to check the backend's supported output data types. - // If the backend doesn't support int64 output, we should use default int32 output data type - // then do a type casting (int32 -> int64) for the output. Refer to the CoreML EP for how to - // support int64 output. + // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API. options.set("outputDataType", "int64"); + options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); const auto& op_type = node.OpType(); if (op_type == "ArgMax") { - output = model_builder.GetBuilder().call("argMax", input, options); + output = model_builder.GetBuilder().call("argMax", input, narrow(axis), options); } else if (op_type == "ArgMin") { - output = model_builder.GetBuilder().call("argMin", input, options); + output = model_builder.GetBuilder().call("argMin", input, narrow(axis), options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ArgMaxMinOpBuilder, unknown op: ", op_type); } @@ -81,15 +74,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia if (!GetShape(*input_defs[0], input_shape, logger)) return false; - // WebNN CPU backend only supports select_last_index = 0. 
- if (device_type == WebnnDeviceType::CPU) { - NodeAttrHelper helper(node); - const auto select_last_index = helper.Get("select_last_index", 0); - if (select_last_index) { - LOGS(logger, VERBOSE) << "ArgMax/ArgMin with select_last_index = 1 is not supported on WebNN CPU backend."; - return false; - } - } return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 23e19d594314..555de68cd60f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -35,18 +35,21 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + if (op_type == "Add") { - output = model_builder.GetBuilder().call("add", input0, input1); + output = model_builder.GetBuilder().call("add", input0, input1, options); } else if (op_type == "Sub") { - output = model_builder.GetBuilder().call("sub", input0, input1); + output = model_builder.GetBuilder().call("sub", input0, input1, options); } else if (op_type == "Mul") { - output = model_builder.GetBuilder().call("mul", input0, input1); + output = model_builder.GetBuilder().call("mul", input0, input1, options); } else if (op_type == "Div") { - output = model_builder.GetBuilder().call("div", input0, input1); + output = model_builder.GetBuilder().call("div", input0, input1, options); } else if (op_type == "Pow") { - output = model_builder.GetBuilder().call("pow", input0, input1); + output = model_builder.GetBuilder().call("pow", input0, input1, options); } else if (op_type == "PRelu") { - output = model_builder.GetBuilder().call("prelu", input0, input1); + output = model_builder.GetBuilder().call("prelu", input0, input1, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a97d71b90de5..a08e1681a846 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -69,8 +69,11 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, node.Name(), " type: ", to_type); } + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = - model_builder.GetBuilder().call("cast", input, emscripten::val(operand_type)); + model_builder.GetBuilder().call("cast", input, emscripten::val(operand_type), options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index e6403a4cd12d..b5c3206072d5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -53,6 +53,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, "GetClipMinMax failed"); options.set("minValue", 
minValue); options.set("maxValue", maxValue); + options.set("label", node.Name()); emscripten::val input = model_builder.GetOperand(input_name); emscripten::val output = model_builder.GetBuilder().call("clamp", input, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index e4f98b09e03c..dedc76b80e97 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -42,8 +42,11 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, inputs.push_back(model_builder.GetOperand(input->Name())); } + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = - model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis); + model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 320aaa03930f..76a8a178678d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -28,7 +28,7 @@ class ConvOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; @@ -242,6 +242,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); ORT_RETURN_IF_ERROR(SetConvBaseOptions( model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); bool depthwise = false; @@ -276,7 +277,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (!is_nhwc || !is_constant_weight) { // The weight_shape has been appended 1's, reshape weight operand. 
std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_filter"); + filter = model_builder.GetBuilder().call("reshape", + filter, + emscripten::val::array(new_shape), + reshape_options); } } @@ -293,6 +299,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N perm = {0, 2, 3, 1}; // L_0231 } transpose_options.set("permutation", emscripten::val::array(perm)); + transpose_options.set("label", node.Name() + "_transpose_filter"); filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); } @@ -323,7 +330,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector output_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); std::vector new_shape = GetVecUint32FromVecInt64(output_shape); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_output"); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -366,6 +378,22 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } + // WebNN CPU backend (TFLite) only supports default dilations and group. + // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 + if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { + NodeAttrHelper helper(node); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + const auto group = helper.Get("group", 1); + if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; + return false; + } + if (group != 1) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; + return false; + } + } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc index 66d502a4e672..93a12a696cce 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc @@ -50,11 +50,22 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil std::vector target_shape{static_cast(input_shape[axis])}; target_shape.insert(target_shape.begin(), axis, 1); target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); - scale = model_builder.GetBuilder().call("reshape", scale, emscripten::val::array(target_shape)); + emscripten::val reshape_scale_options = emscripten::val::object(); + reshape_scale_options.set("label", node.Name() + "_reshape_scale"); + scale = model_builder.GetBuilder().call("reshape", + scale, + emscripten::val::array(target_shape), + reshape_scale_options); + emscripten::val reshape_zero_point_options = emscripten::val::object(); + reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); zero_point = 
model_builder.GetBuilder().call("reshape", - zero_point, emscripten::val::array(target_shape)); + zero_point, + emscripten::val::array(target_shape), + reshape_zero_point_options); } - output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc new file mode 100644 index 000000000000..469acbc7a7e1 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class DropoutOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void DropoutOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip ratio and training_mode if present. + for (size_t i = 1; i < node.InputDefs().size(); i++) { + const auto input_name = node.InputDefs()[i]->Name(); + model_builder.AddInitializerToSkip(input_name); + model_builder.AddInputToSkip(input_name); + } +} + +Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + // WebNN EP only supports test mode. So we don't need to care about other inputs or + // attributes about training mode. Simply use WebNN's identity op to copy the input. + emscripten::val output = model_builder.GetBuilder().call("identity", input, options); + + model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); + + // If mask output is requested as output it will contain all ones (bool tensor). 
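+  // (In inference mode nothing is dropped, so the requested mask is simply "keep everything";
+  // it is materialized below as a constant uint8 tensor filled with 1s.)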
+  if (output_defs.size() > 1) {
+    std::vector<int64_t> mask_shape;
+    ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape");
+    std::vector<uint32_t> dims = GetVecUint32FromVecInt64(mask_shape);
+
+    emscripten::val desc = emscripten::val::object();
+    desc.set("dataType", "uint8");
+    desc.set("dimensions", emscripten::val::array(dims));
+    const auto num_elements = narrow(Product(mask_shape));
+    emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
+    ones_buffer.call("fill", 1);
+
+    emscripten::val mask_output = model_builder.GetBuilder().call<emscripten::val>("constant", desc, ones_buffer);
+
+    emscripten::val options = emscripten::val::object();
+    options.set("label", output_defs[1]->Name() + "_identity");
+    // Add an additional identity op in case the mask is the output of a WebNN graph,
+    // because WebNN does not support a constant operand as a graph output.
+    mask_output = model_builder.GetBuilder().call<emscripten::val>("identity", mask_output, options);
+    model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output));
+  }
+  return Status::OK();
+}
+
+// Operator support related.
+bool DropoutOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
+                                         const Node& node,
+                                         const WebnnDeviceType /* device_type */,
+                                         const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  std::vector<int64_t> input_shape;
+  if (!GetShape(*input_defs[0], input_shape, logger))
+    return false;
+
+  return true;
+}
+
+void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
+  op_registrations.builders.push_back(std::make_unique<DropoutOpBuilder>());
+  op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
+}
+
+}  // namespace webnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc
index 3b5f64584b82..55746bb1f61f 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc
@@ -31,8 +31,9 @@ Status DynamicQuantizaLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model
   std::vector<int64_t> input_shape;
   ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
   emscripten::val options = emscripten::val::object();
+  options.set("label", node.Name());
 
-  output_array = model_builder.GetBuilder().call("dynamicQuantizeLinear", input);
+  output_array = model_builder.GetBuilder().call<emscripten::val>("dynamicQuantizeLinear", input, options);
   for (size_t i = 0, count = output_array["length"].as(); i < count; i++) {
     model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
index 9c75c00fa927..c8cea833983b 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc
@@ -53,10 +53,14 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   std::vector<int64_t> output_shape;
   ORT_RETURN_IF_NOT(GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape), "Cannot get output shape.");
+  emscripten::val options = emscripten::val::object();
+  options.set("label", node.Name());
+
   emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("expand",
input, - emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(output_shape)), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc index 31b1bd92a950..d0ece026a704 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/flatten_op_builder.cc @@ -52,8 +52,10 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, SafeInt(num_post_axis_elements)}; emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call( - "reshape", inputs, emscripten::val::array(new_shape)); + "reshape", inputs, emscripten::val::array(new_shape), options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 014a08616c44..23233539d34c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -42,6 +42,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); emscripten::val options = emscripten::val::object(); options.set("axis", axis); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("gather", input, indices, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 53f885019ab2..bd452b118fe3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -39,6 +39,8 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "MatMul") { std::vector a_shape; if (!GetShape(*input_defs[a_idx], a_shape, logger)) { @@ -53,23 +55,34 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (a_shape.size() == 1) { extended_a_shape = true; a_shape.insert(a_shape.begin(), 1); + emscripten::val reshape_a_options = emscripten::val::object(); + reshape_a_options.set("label", node.Name() + "_reshape_a"); a = model_builder.GetBuilder().call("reshape", a, - emscripten::val::array(GetVecUint32FromVecInt64(a_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(a_shape)), + reshape_a_options); } // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. 
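       // For example (shapes for illustration only): b of shape [K] becomes [K, 1]; if a was also 1-D
       // ([K] -> [1, K]), the [1, 1] matmul result is reshaped back to a scalar below, matching ONNX
       // MatMul's NumPy-style handling of 1-D operands.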
bool extended_b_shape = false; if (b_shape.size() == 1) { extended_b_shape = true; b_shape.push_back(1); + emscripten::val reshape_b_options = emscripten::val::object(); + reshape_b_options.set("label", node.Name() + "_reshape_b"); b = model_builder.GetBuilder().call("reshape", b, - emscripten::val::array(GetVecUint32FromVecInt64(b_shape))); + emscripten::val::array(GetVecUint32FromVecInt64(b_shape)), + reshape_b_options); } - output = model_builder.GetBuilder().call("matmul", a, b); + output = model_builder.GetBuilder().call("matmul", a, b, options); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node.Name() + "_reshape_output"); // If the inputs are both 1D, reduce the output to a scalar. if (extended_a_shape && extended_b_shape) { - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array()); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(), + reshape_output_options); } // After matrix multiplication the prepended 1 is removed. else if (extended_a_shape) { @@ -78,7 +91,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N new_shape.push_back(narrow(b_shape[i])); } new_shape.push_back(narrow(b_shape.back())); - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_output_options); } // After matrix multiplication the appended 1 is removed. else if (extended_b_shape) { @@ -86,7 +102,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N for (size_t i = 0; i < a_shape.size() - 1; i++) { new_shape.push_back(narrow(a_shape[i])); } - output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(new_shape), + reshape_output_options); } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); @@ -101,9 +120,13 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { b_zero_point = model_builder.GetZeroConstant("uint8"); } - output = model_builder.GetBuilder().call("matmulInteger", a, a_zero_point, b, b_zero_point); + output = model_builder.GetBuilder().call("matmulInteger", + a, + a_zero_point, + b, + b_zero_point, + options); } else { // Gemm - emscripten::val options = emscripten::val::object(); NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); options.set("aTranspose", emscripten::val(transA == 1)); diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index e56e8f6a3eb6..23f3a938fee5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -33,16 +33,18 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Equal") { - output = model_builder.GetBuilder().call("equal", input0, input1); + 
output = model_builder.GetBuilder().call("equal", input0, input1, options); } else if (op_type == "Greater") { - output = model_builder.GetBuilder().call("greater", input0, input1); + output = model_builder.GetBuilder().call("greater", input0, input1, options); } else if (op_type == "GreaterOrEqual") { - output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1); + output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1, options); } else if (op_type == "Less") { - output = model_builder.GetBuilder().call("lesser", input0, input1); + output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { - output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1); + output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 0168f5927354..1080fd0a3f94 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -43,22 +43,26 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(op_type == "Max" || op_type == "Min", "MaxMinOpBuilder, unknown op: ", op_type); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (input_count == 1) { // For 1 input, just concat the single input as workaround. // TODO: use identity instead once it's available in WebNN. emscripten::val inputs = emscripten::val::array(); inputs.call("push", input0); - output = model_builder.GetBuilder().call("concat", inputs, 0); + output = model_builder.GetBuilder().call("concat", inputs, 0, options); } else { std::string webnn_op_name = op_type == "Max" ? 
"max" : "min"; emscripten::val input1 = model_builder.GetOperand(input_defs[1]->Name()); - output = model_builder.GetBuilder().call(webnn_op_name.c_str(), input0, input1); + output = model_builder.GetBuilder().call(webnn_op_name.c_str(), input0, input1, options); for (size_t input_index = 2; input_index < input_count; ++input_index) { emscripten::val next_input = model_builder.GetOperand(input_defs[input_index]->Name()); - output = model_builder.GetBuilder().call(webnn_op_name.c_str(), output, next_input); + emscripten::val next_options = emscripten::val::object(); + next_options.set("label", node.Name() + "_" + input_defs[input_index]->Name()); + output = model_builder.GetBuilder().call(webnn_op_name.c_str(), output, next_input, next_options); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index a2aa0df5586e..4d068baf35e7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -42,6 +42,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder const auto rank = input_shape.size(); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); std::vector scale_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); @@ -116,7 +117,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder new_shape.erase(insertion_point, insertion_point + excess_rank); *insertion_point = sum; } - input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + emscripten::val reshape_input_options = emscripten::val::object(); + reshape_input_options.set("label", node.Name() + "_reshape_input"); + input = model_builder.GetBuilder().call("reshape", + input, + emscripten::val::array(new_shape), + reshape_input_options); } if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { @@ -126,8 +132,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder // Reshape back to the original output shape for 3D input. 
if (input_shape.size() != 4) { std::vector output_shape = GetVecUint32FromVecInt64(input_shape); - output = model_builder.GetBuilder().call( - "reshape", output, emscripten::val::array(output_shape)); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node.Name() + "reshape_output"); + output = model_builder.GetBuilder().call("reshape", + output, + emscripten::val::array(output_shape), + reshape_output_options); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index bc90821ba4ed..071155a2fb37 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -73,6 +73,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto pad_mode = helper.Get("mode", std::string("constant")); @@ -145,9 +146,12 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, starts.push_back(start_padding[i] >= 0 ? SafeInt(0) : SafeInt(-start_padding[i])); sizes.push_back(SafeInt(input_shape[i] + start_padding[i] + end_padding[i])); } + emscripten::val slice_options = emscripten::val::object(); + slice_options.set("label", node.Name() + "_slice_output"); output = model_builder.GetBuilder().call("slice", output, emscripten::val::array(starts), - emscripten::val::array(sizes)); + emscripten::val::array(sizes), + slice_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 8b3eecf35fcc..0af62dacedbd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -59,6 +59,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 461050849385..3e6d4d9820e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -57,6 +57,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); const auto keep_dims = helper.Get("keepdims", 1); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); options.set("keepDimensions", keep_dims == 1); std::vector axes_data; diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index b5005269b96a..a7911683f035 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ 
b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -58,8 +58,13 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::transform(target_shape.cbegin(), target_shape.cend(), std::back_inserter(new_shape), [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("reshape", - input, emscripten::val::array(new_shape)); + input, + emscripten::val::array(new_shape), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index c4ca980fec71..2218c858951d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -106,6 +106,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); const auto mode = helper.Get("mode", "nearest"); if (mode == "linear") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 1552023d3f87..0eb7dafdffe4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -55,8 +55,15 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val sizes = emscripten::val::array(); sizes.call("push", slice_length); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + // Since WebNN doesn't support Shape op, we use constant + slice ops as workaround. 
- emscripten::val output = model_builder.GetBuilder().call("slice", shape_constant, starts, sizes); + emscripten::val output = model_builder.GetBuilder().call("slice", + shape_constant, + starts, + sizes, + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index fb452aec1c92..bef13841c646 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -97,9 +97,12 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, sizes.begin(), [](int64_t i, int64_t j) { return SafeInt(i - j); }); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("slice", inputs, emscripten::val::array(starts), - emscripten::val::array(sizes)); + emscripten::val::array(sizes), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 95c1dbd51806..798cfabae65d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -42,7 +42,9 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, int32_t axis = helper.Get("axis", default_axis); axis = static_cast(HandleNegativeAxis(axis, input_size)); - emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("softmax", input, axis, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index ea3b8ef384dd..4c59b694d690 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -49,6 +49,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const size_t rank = input_shape.size(); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); NodeAttrHelper helper(node); int32_t axis = helper.Get("axis", 0); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 8e6feb62fa8c..5eff96873b8c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -54,7 +54,6 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); const auto input_rank = input_shape.size(); - emscripten::val options = emscripten::val::object(); std::vector axes_data; auto rank = input_rank; @@ -111,7 +110,12 @@ Status 
SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil "SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); } - output = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("reshape", + input, + emscripten::val::array(new_shape), + options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 841e2d18244d..2ed8330bf25b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -32,9 +32,11 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); emscripten::val input2 = model_builder.GetOperand(node.InputDefs()[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); emscripten::val output = emscripten::val::object(); if (op_type == "Where") { - output = model_builder.GetBuilder().call("where", input0, input1, input2); + output = model_builder.GetBuilder().call("where", input0, input1, input2, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TernaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 3921b1da188c..03c88ad9db88 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -42,6 +42,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); std::vector permutation = GetVecUint32FromVecInt64(perm); options.set("permutation", emscripten::val::array(permutation)); emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index e4b7021d49b3..0c818533918a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -46,6 +46,7 @@ Status TriangularOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val output = emscripten::val::object(); NodeAttrHelper helper(node); emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); const bool upper = helper.Get("upper", 1); options.set("upper", upper); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index e0016de8e69b..061404c8a9ce 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ 
-30,35 +30,37 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); emscripten::val output = emscripten::val::object(); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); if (op_type == "Abs") { - output = model_builder.GetBuilder().call("abs", input); + output = model_builder.GetBuilder().call("abs", input, options); } else if (op_type == "Ceil") { - output = model_builder.GetBuilder().call("ceil", input); + output = model_builder.GetBuilder().call("ceil", input, options); } else if (op_type == "Cos") { - output = model_builder.GetBuilder().call("cos", input); + output = model_builder.GetBuilder().call("cos", input, options); } else if (op_type == "Erf") { - output = model_builder.GetBuilder().call("erf", input); + output = model_builder.GetBuilder().call("erf", input, options); } else if (op_type == "Exp") { - output = model_builder.GetBuilder().call("exp", input); + output = model_builder.GetBuilder().call("exp", input, options); } else if (op_type == "Floor") { - output = model_builder.GetBuilder().call("floor", input); + output = model_builder.GetBuilder().call("floor", input, options); } else if (op_type == "Identity") { - output = model_builder.GetBuilder().call("identity", input); + output = model_builder.GetBuilder().call("identity", input, options); } else if (op_type == "Log") { - output = model_builder.GetBuilder().call("log", input); + output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { - output = model_builder.GetBuilder().call("neg", input); + output = model_builder.GetBuilder().call("neg", input, options); } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input); + output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { - output = model_builder.GetBuilder().call("reciprocal", input); + output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { - output = model_builder.GetBuilder().call("sin", input); + output = model_builder.GetBuilder().call("sin", input, options); } else if (op_type == "Sqrt") { - output = model_builder.GetBuilder().call("sqrt", input); + output = model_builder.GetBuilder().call("sqrt", input, options); } else if (op_type == "Tan") { - output = model_builder.GetBuilder().call("tan", input); + output = model_builder.GetBuilder().call("tan", input, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 6b0e1495f552..b21f717eedc7 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -20,14 +20,20 @@ namespace onnxruntime { namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const emscripten::val& builder, - const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type) + const emscripten::val& context, const DataLayout preferred_layout, + const WebnnDeviceType wnn_device_type) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), - wnn_builder_(builder), preferred_layout_(preferred_layout), - 
wnn_device_type_(wnn_device_type) {} + wnn_device_type_(wnn_device_type) { + // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() + // is only allowed to be called once. + wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); + if (!wnn_builder_.as()) { + ORT_THROW("Failed to create WebNN builder."); + } +} Status ModelBuilder::Initialize() { PreprocessInitializers(); @@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { if (!wnn_graph.as()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); } + // Explicitly release the WebNN builder to free memory. + wnn_builder_ = emscripten::val::undefined(); model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 6a1688f16d2a..b1561f009aa2 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -22,8 +22,8 @@ class IOpBuilder; class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const emscripten::val& builder, - const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type); + const emscripten::val& context, const DataLayout preferred_layout, + const WebnnDeviceType wnn_device_type); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -62,8 +62,8 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; - emscripten::val wnn_context_ = emscripten::val::object(); - emscripten::val wnn_builder_ = emscripten::val::object(); + emscripten::val wnn_context_ = emscripten::val::undefined(); + emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; InlinedHashMap wnn_operands_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index dfe015725c12..862cf5ded15b 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -81,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateConcatOpBuilder("Concat", op_registrations); } + { // Dropout + CreateDropoutOpBuilder("Dropout", op_registrations); + } + { // Quantize/Dequantize CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations); CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 818ff094fb64..e11938d8fa40 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -26,6 +26,7 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateDropoutOpBuilder(const 
std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 0da0dfc6dfb2..1cd382c1e75e 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } - wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); - if (!wnn_builder_.as()) { - ORT_THROW("Failed to create WebNN builder."); - } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto& logger = *GetLogger(); - if (!wnn_builder_.as()) { - // The GetCapability function may be called again after Compile due to the logic in the - // PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc). - // We need to re-create the wnn_builder_ here to avoid it's been released in last Compile. - wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + if (!wnn_builder.as()) { + ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { return result; @@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); + // Build map from input name to its index in input definitions. { InlinedHashMap input_map; @@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vectortype); return onnxruntime::Status::OK(); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cc3a9943ca0a..5eed7c5c6f2b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1603,6 +1603,11 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, logger, GraphPartitioner::Mode::kOrtFormatLoad)); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // a compiling EP (e.g. CoreML) may copy initializers to its own memory. run the cleanup of unused initializers + // so that they can be freed. 
+ ORT_RETURN_IF_ERROR(graph.RemovedUnusedInitializersOrtFormat()); +#endif return Status::OK(); } @@ -1610,7 +1615,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1618,7 +1624,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable, intra_op_thread_pool); + optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2007,7 +2013,8 @@ common::Status InferenceSession::Initialize() { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, - cpu_ep, GetIntraOpThreadPoolToUse())); + cpu_ep, GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3170,7 +3177,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3180,7 +3188,8 @@ common::Status InferenceSession::AddPredefinedTransformers( : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } }(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5cf5ff9b3bd0..1a5484ddc005 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2396,7 +2396,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_19 = { +static constexpr OrtApi ort_api_1_to_20 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. 
// Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -2763,16 +2763,16 @@ static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2 static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.19.0", +static_assert(std::string_view(ORT_VERSION) == "1.20.0", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it -// 2. If there were any APIs added to ort_api_1_to_19 above: +// 2. If there were any APIs added to ort_api_1_to_20 above: // a. Add the 'End of version #' markers (pattern above should be obvious) // b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version >= 1 && version <= ORT_API_VERSION) - return &ort_api_1_to_19; + return &ort_api_1_to_20; fprintf(stderr, "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build." diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 1d21933e9cba..6d6940590602 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -496,6 +496,7 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) override { return p->add_tensors(); } // GraphProto (wrapped) + std::unique_ptr GraphProto__construct() override { return std::make_unique(); } void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) override { delete p; } const ONNX_NAMESPACE::ValueInfoProto& GraphProto__input(const ONNX_NAMESPACE::GraphProto* p, int index) override { return p->input(index); } @@ -1781,6 +1782,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.cudnn_conv_use_max_workspace = 1; cuda_options_converted.enable_cuda_graph = 0; cuda_options_converted.prefer_nhwc = 0; + cuda_options_converted.fuse_conv_bias = 0; cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; cuda_options_converted.use_ep_level_unified_stream = 0; @@ -1931,12 +1933,31 @@ void ORTSessionOptionsToOrtOpenVINOProviderOptions(ProviderOptions& ov_options, kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; if (disable_cpu_fallback) ov_options["disable_cpu_fallback"] = "true"; + + // values from session options will override the providerOptions Value + bool so_epctx_enable = session_options->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextEnable, "0") == "1"; + if (so_epctx_enable) + ov_options["so_export_ep_ctx_blob"] = "true"; + + std::string so_cache_path = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "").c_str(); + ov_options["so_epctx_path"] = so_cache_path; + + // Default embedMode is 1. 
Saving the compiled model contents as a Epctx node attribute + bool so_epctx_embed_mode = session_options->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextEmbedMode, "1") == "0"; + if (so_epctx_embed_mode) { + // defaults to true + ov_options["so_epctx_embed_mode"] = "false"; + } } std::shared_ptr OpenVINOProviderFactoryCreator::Create(ProviderOptions* provider_options_map, const SessionOptions* session_options) { - if (session_options) + // Append session options applicable for EP to EP Provider options. + if (session_options) { onnxruntime::ORTSessionOptionsToOrtOpenVINOProviderOptions(*provider_options_map, session_options); + } return s_library_openvino.Get().CreateExecutionProviderFactory(provider_options_map); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 679ccce7fb07..ffcd339c0ca3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -40,6 +40,10 @@ #include // for CUDNN_MAJOR #endif +#if defined(USE_COREML) +#include "core/providers/coreml/coreml_provider_factory.h" +#endif + #include // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct, @@ -1161,7 +1165,30 @@ std::unique_ptr CreateExecutionProviderInstance( #if !defined(__APPLE__) LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build."; #endif - return onnxruntime::CoreMLProviderFactoryCreator::Create(0)->CreateProvider(); + uint32_t coreml_flags = 0; + + const auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + const ProviderOptions& options = it->second; + auto flags = options.find("flags"); + if (flags != options.end()) { + const auto& flags_str = flags->second; + + if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY; + } + + if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES; + } + + if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM; + } + } + } + + return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); #endif } else if (type == kXnnpackExecutionProvider) { #if defined(USE_XNNPACK) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 2f197cc7f31c..aab04485246d 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -418,6 +418,9 @@ def quantize_weight_per_channel_impl( zero_point_list = [] scale_list = [] quantized_per_channel_data_list = [] + weights_shape = list(weights.shape) + reshape_dims = list(weights_shape) # deep copy + reshape_dims[channel_axis] = 1 # only one per channel for reshape for i in range(channel_count): per_channel_data = weights.take(i, channel_axis) channel_override_index = i if i < num_channel_overrides else 0 @@ -460,17 +463,10 @@ def quantize_weight_per_channel_impl( zero_point_list.append(zero_point) scale_list.append(scale) - quantized_per_channel_data_list.append(quantized_per_channel_data) + quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims)) # combine per_channel_data into one - weights_shape = list(weights.shape) - 
reshape_dims = list(weights_shape) # deep copy - reshape_dims[channel_axis] = 1 # only one per channel for reshape - quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims) - for i in range(1, len(quantized_per_channel_data_list)): - channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims) - quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis) - + quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis) q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale" diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 40a4a4d26dc1..cc8bd622df9b 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -797,8 +797,8 @@ def parse_args(): parser.add_argument( "--quant_format", default="QOperator", - type=QuantFormat, - choices=list(QuantFormat), + type=str, + choices=["QOperator", "QDQ"], help="QuantFormat {QOperator, QDQ}" "QOperator format quantizes the model with quantized operators directly." "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", @@ -814,7 +814,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model - quant_format = args.quant_format + quant_format = QuantFormat[args.quant_format] if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") diff --git a/onnxruntime/python/tools/quantization/operators/norm.py b/onnxruntime/python/tools/quantization/operators/norm.py index 8c4c6c78582a..10d96cc49855 100644 --- a/onnxruntime/python/tools/quantization/operators/norm.py +++ b/onnxruntime/python/tools/quantization/operators/norm.py @@ -12,7 +12,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert node.op_type == "InstanceNormalization" or node.op_type == "LayerNormalization" + assert node.op_type in {"InstanceNormalization", "LayerNormalization", "BatchNormalization"} # Input self.quantizer.quantize_activation_tensor(node.input[0]) diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index b00e830a2a36..caac829126e3 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -82,6 +82,7 @@ "Where": QDQWhere, "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, + "BatchNormalization": QDQNormalization, } diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index dc2b38f3928a..a9ff623fb696 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -691,6 +691,9 @@ def create_multihead_attention_node( return None # Add bias to inputs for MHA + # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume + # bias has been added to key and value when they are in BNSH format, so only bias for query is used. + # Need add checks if we found such assumption is not true. 
if not self.disable_multi_head_attention_bias: bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name) mha_inputs.append(bias_name) diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 689b14ea9a68..979f872ac4c5 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,11 +1,11 @@ torch>=1.13.0 -transformers>=4.24.0 +transformers>=4.24.0,<= 4.42.4 openai-whisper>=20231117 ffmpeg-python datasets soundfile librosa -optimum +optimum<=1.21.2 onnxruntime-extensions>=0.9.0 onnx==1.16.1 protobuf==3.20.2 diff --git a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc index f700e3100301..027d4b3fff1b 100644 --- a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. // BiasDropout kernel is only implemented for CUDA/ROCM -#if defined(USE_CUDA) || defined(USE_ROCM) +#if (defined(USE_CUDA) && !defined(USE_CUDA_MINIMAL)) || defined(USE_ROCM) #ifdef _MSC_VER #pragma warning(disable : 4389) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index dedc01de9655..548f24e8ac69 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -263,9 +263,10 @@ void RunTest(const TestOptions& opts, } // namespace TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); for (auto M : {1, 2, 100}) { - for (auto N : {1, 2, 32, 288}) { - for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto N : {/*2560, */ 1, 2, 32, 288}) { + for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { for (auto accuracy_level : {0, 1, 4}) { TestOptions base_opts{}; diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc index 3f298b0a8f8e..e780d35df08b 100644 --- a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc @@ -30,7 +30,8 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes, bool use_float16, bool weight_is_initializer = false) { int min_cuda_architecture = use_float16 ? 
530 : 0; - bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + // NHWC implementation doesn't handle W in NHWC layout if it's not an initializer + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && weight_is_initializer; bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 4e9e80b180e9..43d3782be328 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -2078,6 +2078,7 @@ TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; int gather_count = 0; + ASSERT_GT(plan->execution_plan.size(), 1) << "Number of execution plans should be greater than 1"; for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index f553682975f1..c6d958536f48 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -74,7 +74,9 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) if (provider_type == 1) { #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); std::cout << "Running simple inference with cuda provider" << std::endl; #else return Ort::Session(nullptr); diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp index 9d15c9a6bf99..6d21ed2eef86 100644 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -9,10 +9,10 @@ #include "core/util/thread_utils.h" static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -37,10 +37,10 @@ static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) } static void BM_MlasQuantizeBlockwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) { } static void 
BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); bool add8 = state.range(4) != 0; int quant_num_M = (M + quant_block_size - 1) / quant_block_size; int blob_size = (quant_block_size + 1) / 2; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 354621eff42b..73c78b8cc3d4 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -53,6 +53,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, std::vector QuantBData(QuantBDataSizeInBytes); std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + bool has_zp_input = !Symmetric; MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), @@ -71,15 +72,17 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), tp.get()); } MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; - params.QuantBData = PackedQuantBData != nullptr - ? static_cast(PackedQuantBData.get()) - : static_cast(QuantBData.data()); + if (PackedQuantBData != nullptr) + params.QuantBDataWorkspace = static_cast(PackedQuantBData.get()); + else + params.QuantBDataWorkspace = static_cast(QuantBData.data()); params.QuantBScale = QuantBScale.data(); params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); params.Bias = HasBias ? Bias.data() : nullptr; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index f391027de4d5..0710981fa17c 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -55,8 +55,8 @@ class MlasSQNBitGemmTest : public MlasTestBase { size_t K, const float* A, size_t lda, - const void* QuantBData, - const void* PackedQuantBData, + const void* /*QuantBData*/, + const void* PackedQuantBDataWorkspace, const float* QuantBScale, const void* QuantBZeroPoint, const float* Bias, @@ -71,7 +71,12 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.Bias = Bias; params.C = C; params.ldc = ldc; - params.QuantBData = PackedQuantBData != nullptr ? 
PackedQuantBData : QuantBData; +#ifdef MLAS_TARGET_AMD64_IX86 + if (ComputeType == CompInt8) { + params.QuantBDataWorkspace = PackedQuantBDataWorkspace; + } +#endif + params.PackedQuantBData = static_cast(PackedQuantBDataWorkspace); params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; @@ -213,12 +218,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { auto print_matrix = [](size_t nrows, size_t ncols, const float* data) { for (size_t row = 0; row < nrows; ++row) { for (size_t col = 0; col < ncols; ++col) { - std::cout << data[row * ncols + col] << "\t"; + std::cout << data[row * ncols + col] << ", "; } std::cout << "\n"; } }; + auto print_matrix_col = [](size_t nrows, size_t ncols, size_t col, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + std::cout << data[row * ncols + col] << ", "; + } + std::cout << "\n"; + }; + std::cout << "A:\n"; print_matrix(M, K, A); std::cout << "B:\n"; @@ -258,14 +270,25 @@ class MlasSQNBitGemmTest : public MlasTestBase { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } - void* PackedQuantBData = nullptr; + void* PackedQuantBDataWorkspace = nullptr; if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { - PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); + bool has_zp_input = QuantBZeroPoint != nullptr; + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, GetMlasThreadPool()); } + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + Threadpool); + if (ComputeType == CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else if (ComputeType == CompInt8) { @@ -275,15 +298,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; } - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); - size_t f = 0; for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { @@ -382,7 +396,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixtureAdd->Relu will be transformed to FusedConv -TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 1); - ASSERT_TRUE(op_to_count["Relu"] == 1); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 0); // 
Add removed from graph - ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph -} - -// Currently the ConvAddRelu fusion is only backed by a float kernel for the -// the CUDA EP. - -// When we see the corresponding pattern for the fp16 data type, the fusion -// should not be triggered as there is no kernel to back the fused pattern. - -// TODO(hasesh): Limit the test to using the CUDA EP for now as the level of -// data type support in other compatible EPs is still yet to be ascertained. - -// TODO(hasesh): If at all the fp16 type is supported for the fusion, adjust/remove -// this test. -TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["Add"], 1); - ASSERT_EQ(op_to_count["Relu"], 1); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register( - std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - op_to_count = CountOpsInGraph(graph); - ASSERT_EQ(op_to_count["Add"], 1); // Add not removed from graph (fusion not triggered) - ASSERT_EQ(op_to_count["Relu"], 1); // Relu not removed from graph (fusion not triggered) -} - +#if !defined(DISABLE_CONTRIB_OPS) // Conv->Add->Relu will be left intact since there is Identity depend on Add TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); +#if defined(USE_JSEP) for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kJsExecutionProvider); } +#else + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kCpuExecutionProvider); + } +#endif std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 1); ASSERT_TRUE(op_to_count["Relu"] == 1); @@ -2073,9 +2028,15 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) { std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); +#if defined(USE_JSEP) for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kJsExecutionProvider); } +#else + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kCpuExecutionProvider); + } +#endif std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 1); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -2165,17 +2126,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); -#ifdef USE_CUDA - for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } -#elif defined(USE_ROCM) +#if defined(USE_JSEP) for (auto& 
node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kJsExecutionProvider); } -#elif defined(USE_JSEP) +#else for (auto& node : p_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kJsExecutionProvider); + node.SetExecutionProviderType(kCpuExecutionProvider); } #endif std::map op_to_count_before_fusion = CountOpsInGraph(graph); @@ -2187,14 +2144,7 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count_after_fusion = CountOpsInGraph(graph); -#if defined(USE_CUDA) || defined(USE_ROCM) - std::set cuda_rocm_supported = {"Relu"}; - if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) { - ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); - } else { - ASSERT_EQ(op_to_count_after_fusion[model.second], 0); - } -#elif defined(USE_JSEP) +#if defined(USE_JSEP) std::set js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"}; if (js_supported.find(model.second) == js_supported.end()) { ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 2cbfbbb31764..03a71868a3dc 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -246,14 +246,14 @@ Status TestGraphTransformer(const std::function& ORT_RETURN_IF_ERROR(pre_graph_checker(graph)); } #if SAVE_TEST_GRAPH - ORT_RETURN_IF_ERROR(Model::Save(model, "model_original.onnx")); + ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_original.onnx"))); #endif ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger)); if (post_graph_checker) { ORT_RETURN_IF_ERROR(post_graph_checker(graph)); } #if SAVE_TEST_GRAPH - ORT_RETURN_IF_ERROR(Model::Save(model, "model_optimized.onnx")); + ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_optimized.onnx"))); #endif }; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 14c5b60d6e0b..a043d6553bdf 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -12,6 +12,7 @@ #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -979,6 +980,52 @@ TEST(QDQTransformerTests, ReshapeDropQDQ) { RunReshapeDropQDQTestCase({1, 3, 2, 2}, {1, 12}, false, 21); // Use int16 ONNX QDQ ops } +// Runs a test case that checks if Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q if the quantization scale is +// negative. 
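// Why a negative scale matters here: with a negative scale the ordering of quantized
// values is the reverse of the real values, so MaxPool applied to the quantized
// tensor picks the wrong element. For example, with scale = -0.003 and
// zero_point = 128, q = 130 dequantizes to (130 - 128) * -0.003 = -0.006 while
// q = 126 dequantizes to +0.006; the real-valued max is +0.006 (q = 126), but MaxPool
// over the quantized data would return q = 130. The Q/DQ pair around MaxPool
// therefore has to be kept whenever the scale is negative.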
+template +static void RunMaxPoolNegativeScaleDropQDQTestCase() { + auto build_test_case = [](ModelTestBuilder& builder) { + constexpr QuantType qmin = std::numeric_limits::min(); + constexpr QuantType qmax = std::numeric_limits::max(); + + const std::vector input_shape = {1, 17, 17, 3}; + auto* input_arg = builder.MakeInput(input_shape, qmin, qmax); + auto* output_arg = builder.MakeOutput(); + + constexpr float scale = -0.003f; + QuantType zero_point = 1 + (qmax + qmin) / 2; + + auto* input_arg_dq = builder.MakeIntermediate(); + auto* maxpool_output = builder.MakeIntermediate(); + + builder.AddDequantizeLinearNode(input_arg, scale, zero_point, input_arg_dq); + + Node& maxpool_node = builder.AddNode("MaxPool", {input_arg_dq}, {maxpool_output}); + maxpool_node.AddAttribute("auto_pad", "VALID"); + maxpool_node.AddAttribute("kernel_shape", std::vector({2, 2})); + + builder.AddQuantizeLinearNode(maxpool_output, scale, zero_point, output_arg); + }; + + auto check_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MaxPool"], 1); + EXPECT_EQ(op_to_count["QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"], 1); + }; + + constexpr int opset = 21; + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset); +} + +// Checks that Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q for negative scale. Uses 8-bit and 16-bit Q/DQ ops. +TEST(QDQTransformerTests, MaxpoolDontDropQDQForNegativeScale) { + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); + RunMaxPoolNegativeScaleDropQDQTestCase(); +} + // Runs a test case that checks if Q/DQ nodes are dropped from DQ -> (Un)Squeeze -> Q. template static void RunSqueezeUnsqueezeDropQDQTestCase(const std::string& squeeze_type, @@ -3084,6 +3131,57 @@ TEST(QDQTransformerTests, QDQPropagation_QBackward) { #endif } +// Test backwards propagation of a QuantizeLinear node that uses the "output_dtype" attribute +// to set the quantization type (i.e., does not have an explicit zero-point input). This tests +// the copying of attributes for QDQ propagation. +TEST(QDQTransformerTests, QDQPropagation_QBackward_NoZP_OutputDtypeAttribute) { + auto test_case = [&](ONNX_NAMESPACE::TensorProto_DataType q_output_type) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 2, 2}, {-2.0f, 0.0f, 1.0f, 2.0f}); + auto* output_arg = builder.MakeOutput(); + + // add Add + auto* const_1_input = builder.MakeScalarInitializer(1.0f); + auto* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input_arg, const_1_input}, {add_output}); + + // add Transpose + auto* transpose_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {add_output}, {transpose_output}); + + // add Q with a "output_dtype" attribute. Omit the zero-point input (defaults to 0). 
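// ("output_dtype" was added to QuantizeLinear in opset 21; when the zero-point
// input is omitted, it determines the quantized output type and defaults to uint8.)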
+ constexpr float qdq_scale = 1.0f; + Node& q_node = builder.AddQuantizeLinearNode(transpose_output, qdq_scale, output_arg); + q_node.AddAttribute("output_dtype", static_cast(q_output_type)); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + std::vector expected_op_types_in_order = { + "Add", + qdq_keys.quantize_linear, + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, + }; + + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 21); // Opset >= 21 supports the "output_dtype" attribute + }; + + test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + test_case(ONNX_NAMESPACE::TensorProto_DataType_INT8); + test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT16); + test_case(ONNX_NAMESPACE::TensorProto_DataType_INT16); +} + TEST(QDQTransformerTests, QDQPropagation_DQForward) { auto test_case = [&](const std::vector& input_shape, size_t maxpool_dim, @@ -3420,6 +3518,122 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) { #endif } +// Test propagating a DQ forward through a chain of Slice and Transpose operators that have multiple consumers. +// original model: +// in0 -> DQ -> Slice --+--> slice_out +// | +// +--> Add -> out0 +// | +// +--> Transpose --+--> Pow -> out1 +// | | +// | +--> Pow -> out2 +// | +// +--> Transpose --+--> Pow -> out3 +// | +// +--> Pow -> out4 +// expected model: +// in0 -> DQ -> Slice -> Q --+--> DQ -> slice_out +// | +// +--> DQ -> Add -> out0 +// | +// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out1 +// | | +// | +--> DQ -> Pow -> out2 +// | +// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out3 +// | +// +--> DQ -> Pow -> out4 +TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { + auto run_test_case = [&](bool slice_has_graph_output) { + auto build_test_case = [&](ModelTestBuilder& builder) { + std::vector input0_shape = {1, 2, 2, 2}; + std::vector input1_shape = {1, 1, 1, 1}; + auto* input0_arg = builder.MakeInput(input0_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input1_arg = builder.MakeInput(input1_shape, {0.0f}); + auto* output0_arg = builder.MakeOutput(); + auto* output1_arg = builder.MakeOutput(); + auto* output2_arg = builder.MakeOutput(); + auto* output3_arg = builder.MakeOutput(); + auto* output4_arg = builder.MakeOutput(); + + // DQ + constexpr float qdq_scale = 1.0f; + constexpr uint8_t qdq_zero_point = 128; + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input0_arg, qdq_scale, qdq_zero_point, dq_output); + + // Slice + auto* slice_output = slice_has_graph_output ? 
builder.MakeOutput() : builder.MakeIntermediate(); + auto* slice_starts = builder.Make1DInitializer(std::vector{0, 0, 0, 0}); + auto* slice_ends = builder.Make1DInitializer(std::vector{1, 1, 1, 1}); + builder.AddNode("Slice", {dq_output, slice_starts, slice_ends}, {slice_output}); + + // Add + builder.AddNode("Add", {slice_output, input1_arg}, {output0_arg}); + + // Transpose + auto* transpose0_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {slice_output}, {transpose0_output}); + + // Transpose + auto* transpose1_output = builder.MakeIntermediate(); + builder.AddNode("Transpose", {slice_output}, {transpose1_output}); + + // Pows + auto* pow_exp = builder.MakeScalarInitializer(2.0f); + builder.AddNode("Pow", {transpose0_output, pow_exp}, {output1_arg}); + builder.AddNode("Pow", {transpose0_output, pow_exp}, {output2_arg}); + builder.AddNode("Pow", {transpose1_output, pow_exp}, {output3_arg}); + builder.AddNode("Pow", {transpose1_output, pow_exp}, {output4_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + std::vector expected_op_types_in_order; + expected_op_types_in_order.reserve(20); + expected_op_types_in_order.insert(expected_op_types_in_order.end(), + {qdq_keys.dequantize_linear, + "Slice", + qdq_keys.quantize_linear}); + + if (slice_has_graph_output) { + // Should have a DQ before the graph output generated by the Slice. + expected_op_types_in_order.push_back(qdq_keys.dequantize_linear); + } + + expected_op_types_in_order.insert(expected_op_types_in_order.end(), + {qdq_keys.dequantize_linear, + "Add", + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Transpose", + qdq_keys.quantize_linear, qdq_keys.dequantize_linear, + "Pow", + qdq_keys.dequantize_linear, + "Pow"}); + + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 18, 0.0, 0.0, std::make_unique()); + }; + + run_test_case(/*slice_has_graph_output*/ false); + run_test_case(/*slice_has_graph_output*/ true); +} + TEST(QDQTransformerTests, QDQ_Selector_Test) { const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); @@ -3525,7 +3739,8 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { const auto compute_capability = utils::MakeComputeCapability( whole_graph_viewer, nodes, []() { return "sub_graph"; }, - "Test Provider"); + "Test Provider", + /*drop_constant_initializers*/ false); const GraphViewer partial_graph_viewer(graph, *compute_capability->sub_graph); ASSERT_EQ(3, partial_graph_viewer.NumberOfNodes()); diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index e6d4e0a94abd..84c3bc16346f 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -253,7 +253,6 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.machine_config.provider_type_name = onnxruntime::kDnnlExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; - test_config.run_config.optimization_level = ORT_DISABLE_ALL; } else if 
(!CompareCString(optarg, ORT_TSTR("tensorrt"))) { test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("qnn"))) { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 72b5da7aaec9..fc1bdb10d745 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -699,6 +699,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", "GPU.0_FP16", "GPU.1_FP16"}; + size_t num_gpus = 10; + for (size_t i = 0; i <= num_gpus; i++) { + ov_supported_device_types.emplace("GPU." + std::to_string(i)); + } if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { ov_options[key] = value; } else if (deprecated_device_types.find(value) != deprecated_device_types.end()) { diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index d0e08448ce45..182fa4729a88 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -25,7 +25,15 @@ struct DefaultTolerance { static constexpr float relative = 1e-5f; // Allow to have different default absolute tolerance for different providers. - static float get_absolute(const std::string& /*provider_type*/) { + static float get_absolute(const std::string& provider_type /*provider_type*/) { + if (provider_type == kOpenVINOExecutionProvider) { +#ifdef OPENVINO_CONFIG_NPU + return 0.005f; +#else + return absolute; +#endif + } + return absolute; } }; @@ -40,7 +48,15 @@ struct DefaultTolerance { static constexpr float relative = 1e-4f; - static float get_absolute(const std::string& /*provider_type*/) { + static float get_absolute(const std::string& provider_type /*provider_type*/) { + if (provider_type == kOpenVINOExecutionProvider) { +#ifdef OPENVINO_CONFIG_NPU + return 0.005f; +#else + return absolute; +#endif + } + return absolute; } }; @@ -411,7 +427,7 @@ struct TensorCheck { for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { - EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. 
i:" << i; } else { diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index ec9b1614488a..f42f32d63d1f 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -178,7 +178,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormal) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -194,7 +194,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniform) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -210,7 +210,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormalLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -226,7 +226,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniformLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } } diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index 6f81bbbe31d5..9948a6cc8a68 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -119,6 +119,24 @@ TEST(MathOpTest, Clip_Default_uint64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(MathOpTest, Clip_MLFloat16) { + OpTester test("Clip", 12); + + std::vector dims{3, 3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), + MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + test.AddInput("min", {}, {MLFloat16(0.0f)}); + test.AddInput("max", {}, {MLFloat16(6.0f)}); + test.AddOutput("Y", dims, + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + + test.Run(); +} + TEST(MathOpTest, Clip_int32) { OpTester test("Clip", 12); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb3575f2cde8..bd3d21d4929f 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1553,6 +1553,47 @@ TEST(MathOpTest, Min_12_Float_Nan) { } } +TEST(MathOpTest, Min_12_Float_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("min", {3, 1}, + 
{std::numeric_limits::quiet_NaN(), -0.5f, 0.25f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Float_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Double) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1586,12 +1627,53 @@ TEST(MathOpTest, Min_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -1.0, -1.0, -2.0, 0.5, 0.0, 1.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.25}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + 
execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1666,7 +1748,7 @@ TEST(MathOpTest, Min_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16) { +TEST(MathOpTest, Min_12_MLFloat16) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({1.f, 1.f, 1.f})); @@ -1679,7 +1761,7 @@ TEST(MathOpTest, Min_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar0) { OpTester test("Min", 12); test.AddInput("data_0", {}, MakeMLFloat16({-10.f})); @@ -1692,7 +1774,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({2.f, 3.f, 4.f})); @@ -1809,12 +1891,53 @@ TEST(MathOpTest, Max_12_Float_Nan) { std::numeric_limits::quiet_NaN(), -0.5f, 0.0f, -1.0f, 1.0f, 1.0f, 2.0f}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25f, 0.5f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + 
execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1854,12 +1977,53 @@ TEST(MathOpTest, Max_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -0.5, 0.0, -1.0, 1.0, 1.0, 2.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25, 0.5}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1934,7 +2098,7 @@ TEST(MathOpTest, Max_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16) { +TEST(MathOpTest, Max_12_MLFloat16) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -1.f, -1.f})); @@ -1947,7 +2111,7 @@ TEST(MathOpTest, Max_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar0) { OpTester test("Max", 12); test.AddInput("data_0", {}, MakeMLFloat16({-1.f})); @@ -1960,7 +2124,7 @@ TEST(MathOpTest, 
Max_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -2.f, -3.f})); diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index cb5fc8095982..95b274966fbb 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -714,6 +714,241 @@ TEST(ConvFp16Test, Conv2D_group) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Depthwise2D_Bias_Group1_Issue18992) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {MLFloat16(1.0f)}; + vector X_shape = {1, 1, 1, 1}; + vector W = {MLFloat16(0.5f)}; + vector W_shape = {1, 1, 1, 1}; + vector B = {MLFloat16(0.5f)}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 1, 1}; + auto expected_vals = {MLFloat16(1.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Depthwise2D_Bias_Group2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), + MLFloat16(12.0f), MLFloat16(13.0f), MLFloat16(14.0f), + MLFloat16(15.0f), MLFloat16(16.0f), MLFloat16(17.0f)}; + vector X_shape = {1, 2, 3, 3}; + vector W = {MLFloat16(1.0f), MLFloat16(2.0f)}; + vector W_shape = {2, 1, 1, 1}; + vector B = {MLFloat16(1.0f), MLFloat16(-1.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), + MLFloat16(4.0f), MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f), MLFloat16(9.0f), + + MLFloat16(17.0f), MLFloat16(19.0f), MLFloat16(21.0f), + MLFloat16(23.0f), MLFloat16(25.0f), MLFloat16(27.0f), + MLFloat16(29.0f), MLFloat16(31.0f), MLFloat16(33.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Depthwise2D_Bias_Group15) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 15, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + // C = 0 + MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(2.0f), MLFloat16(3.0f), + + // C = 1 + MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), + + // C = 2 + MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), + + // C = 3 + MLFloat16(12.0f), MLFloat16(13.0f), + MLFloat16(14.0f), MLFloat16(15.0f), + + // C = 4 + MLFloat16(16.0f), MLFloat16(17.0f), + MLFloat16(18.0f), 
MLFloat16(19.0f), + + // C = 5 + MLFloat16(20.0f), MLFloat16(21.0f), + MLFloat16(22.0f), MLFloat16(23.0f), + + // C = 6 + MLFloat16(24.0f), MLFloat16(25.0f), + MLFloat16(26.0f), MLFloat16(27.0f), + + // C = 7 + MLFloat16(28.0f), MLFloat16(29.0f), + MLFloat16(30.0f), MLFloat16(31.0f), + + // C = 8 + MLFloat16(32.0f), MLFloat16(33.0f), + MLFloat16(34.0f), MLFloat16(35.0f), + + // C = 9 + MLFloat16(36.0f), MLFloat16(37.0f), + MLFloat16(38.0f), MLFloat16(39.0f), + + // C = 10 + MLFloat16(40.0f), MLFloat16(41.0f), + MLFloat16(42.0f), MLFloat16(43.0f), + + // C = 11 + MLFloat16(44.0f), MLFloat16(45.0f), + MLFloat16(46.0f), MLFloat16(47.0f), + + // C = 12 + MLFloat16(48.0f), MLFloat16(49.0f), + MLFloat16(50.0f), MLFloat16(51.0f), + + // C = 13 + MLFloat16(52.0f), MLFloat16(53.0f), + MLFloat16(54.0f), MLFloat16(55.0f), + + // C = 14 + MLFloat16(56.0f), MLFloat16(57.0f), + MLFloat16(58.0f), MLFloat16(59.0f)}; + vector X_shape = {1, 15, 2, 2}; + vector W = { + // M = 0 + MLFloat16(0.0f), MLFloat16(1.0f), + MLFloat16(2.0f), MLFloat16(3.0f), + + // M = 1 + MLFloat16(4.0f), MLFloat16(5.0f), + MLFloat16(6.0f), MLFloat16(7.0f), + + // M = 2 + MLFloat16(8.0f), MLFloat16(9.0f), + MLFloat16(10.0f), MLFloat16(11.0f), + + // M = 3 + MLFloat16(12.0f), MLFloat16(13.0f), + MLFloat16(14.0f), MLFloat16(15.0f), + + // M = 4 + MLFloat16(16.0f), MLFloat16(17.0f), + MLFloat16(18.0f), MLFloat16(19.0f), + + // M = 5 + MLFloat16(20.0f), MLFloat16(21.0f), + MLFloat16(22.0f), MLFloat16(23.0f), + + // M = 6 + MLFloat16(24.0f), MLFloat16(25.0f), + MLFloat16(26.0f), MLFloat16(27.0f), + + // M = 7 + MLFloat16(28.0f), MLFloat16(29.0f), + MLFloat16(30.0f), MLFloat16(31.0f), + + // M = 8 + MLFloat16(32.0f), MLFloat16(33.0f), + MLFloat16(34.0f), MLFloat16(35.0f), + + // M = 9 + MLFloat16(36.0f), MLFloat16(37.0f), + MLFloat16(38.0f), MLFloat16(39.0f), + + // M = 10 + MLFloat16(40.0f), MLFloat16(41.0f), + MLFloat16(42.0f), MLFloat16(43.0f), + + // M = 11 + MLFloat16(44.0f), MLFloat16(45.0f), + MLFloat16(46.0f), MLFloat16(47.0f), + + // M = 12 + MLFloat16(48.0f), MLFloat16(49.0f), + MLFloat16(50.0f), MLFloat16(51.0f), + + // M = 13 + MLFloat16(52.0f), MLFloat16(53.0f), + MLFloat16(54.0f), MLFloat16(55.0f), + + // M = 14 + MLFloat16(56.0f), MLFloat16(57.0f), + MLFloat16(58.0f), MLFloat16(59.0f)}; + vector W_shape = {15, 1, 2, 2}; + vector B = { + MLFloat16(101.0f), + MLFloat16(102.0f), + MLFloat16(103.0f), + MLFloat16(104.0f), + MLFloat16(105.0f), + MLFloat16(106.0f), + MLFloat16(107.0f), + MLFloat16(108.0f), + MLFloat16(109.0f), + MLFloat16(110.0f), + MLFloat16(111.0f), + MLFloat16(112.0f), + MLFloat16(113.0f), + MLFloat16(114.0f), + MLFloat16(115.0f)}; + vector B_shape = {15}; + vector Y_shape = {1, 15, 1, 1}; + auto expected_vals = { + MLFloat16(115.0f), // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 101.0 + MLFloat16(228.0f), + MLFloat16(469.0f), + MLFloat16(838.0f), + MLFloat16(1335.0f), + MLFloat16(1960.0f), + MLFloat16(2713.0f), // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 107.0 + MLFloat16(3594.0f), + MLFloat16(4603.0f), + MLFloat16(5740.0f), + MLFloat16(7005.0f), + MLFloat16(8398.0f), + MLFloat16(9919.0f), // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 113.0 + MLFloat16(11568.0f), // 52.0*52.0 + 53.0*53.0 + 54.0*54.0 + 55.0*55.0 + 114.0 + MLFloat16(13345.0f) // 56.0*56.0 + 57.0*57.0 + 58.0*58.0 + 59.0*59.0 + 115.0 + }; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + 
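For reference, the expected_vals in the Group15 test above follow directly from the depthwise (group == C) definition: each output element is the dot product of one channel's 2x2 input patch with that channel's 2x2 filter, plus that channel's bias. A minimal standalone check of that arithmetic (illustrative only, not part of this patch; plain float is used instead of MLFloat16):

#include <cstdio>

int main() {
  // Channel c of the test holds input and filter values 4*c .. 4*c+3 and bias 101 + c.
  for (int c = 0; c < 15; ++c) {
    float acc = 101.0f + static_cast<float>(c);
    for (int k = 0; k < 4; ++k) {
      const float v = static_cast<float>(4 * c + k);
      acc += v * v;  // input value equals the filter value for this test data
    }
    std::printf("channel %d -> %.1f\n", c, acc);  // 115, 228, 469, ..., 11568, 13345
  }
  return 0;
}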
TEST(ConvFp16Test, ConvDimWithZero) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -1074,4 +1309,4 @@ TEST(ConvFp16Test, SharedPrepackedWeights) { } // namespace test } // namespace onnxruntime -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED \ No newline at end of file +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 0efa78af2795..25caa732efa2 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -25,6 +25,7 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, const std::initializer_list& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, + optional epsilon = optional(), OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", int opset = 7) { @@ -56,11 +57,13 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, test.AddOutput("Y", expected_output_shape, expected_output); + if (epsilon.has_value()) { + test.SetOutputTolerance(*epsilon); + } + std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); - // Disable CUDA NHWC execution provider as it is currently flaky - excluded_providers.insert(kCudaNHWCExecutionProvider); // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs. excluded_providers.insert(kQnnExecutionProvider); @@ -189,10 +192,15 @@ TEST(ConvTest, Conv1D_Bias) { vector Y_shape = {2, 1, 4}; auto expected_vals = {0.37892162799835205f, 0.4625728130340576f, 0.4934738576412201f, 0.44801419973373413f, 0.37892162799835205f, 0.2499445676803589f, 0.31682088971138f, 0.32773756980895996f}; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // For the CUDA EP: Due to CUDNN Frontend using TF32 for FP32 operations we get a higher error than using FP32 only, + // as TF32 has a 10 bit mantissa. + float epsilon = 1.1e-5f; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon); // CoreML EP requires weight to be an initializer - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon); } // Conv47 @@ -240,7 +248,7 @@ TEST(ConvTest, Conv1D_Invalid_Input_Shape) { vector X_shape = {1, 1, 1}; vector dummy_shape = {1, 1, 2}; auto dummy_vals = {0.0f, 0.0f}; - TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, + TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, optional(), OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " "Both inferred and declared dimension have values but they differ. Inferred=0 Declared=2 Dimension=2", @@ -263,7 +271,7 @@ TEST(ConvTest, Conv2D_Invalid_Input_Shape) { vector dummy_shape = {2, 2, 1, 2}; auto dummy_vals = {-0.0f, 0.0f, -0.0f, -0.0f, -0.0f, 0.0f, -0.0f, -0.0f}; - TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, + TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, optional(), OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. 
" "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0", @@ -620,7 +628,12 @@ TEST(ConvTest, Conv3D_Bias) { -0.47542816400527954f, -0.5078460574150085f, -0.4205915927886963f, -0.5584549903869629f, -0.39770257472991943f, -0.45317384600639343f, -0.5598302483558655f, -0.2542789578437805f, -0.5359901785850525f, -0.48090484738349915f, -0.38603779673576355f, -0.4991581439971924f}; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // For the CUDA EP: Due to CUDNN Frontend using TF32 for FP32 operations we get a higher error than using FP32 only, + // as TF32 has a 10 bit mantissa. + float epsilon = 2.1e-4f; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon); } TEST(ConvTest, Conv2D_group) { @@ -647,6 +660,241 @@ TEST(ConvTest, Conv2D_group) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvTest, Depthwise2D_Bias_Group1_Issue18992) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {1.0f}; + vector X_shape = {1, 1, 1, 1}; + vector W = {0.5f}; + vector W_shape = {1, 1, 1, 1}; + vector B = {0.5f}; + vector B_shape = {1}; + vector Y_shape = {1, 1, 1, 1}; + auto expected_vals = {1.0f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Depthwise2D_Bias_Group2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + + 9.0f, 10.0f, 11.0f, + 12.0f, 13.0f, 14.0f, + 15.0f, 16.0f, 17.0f}; + vector X_shape = {1, 2, 3, 3}; + vector W = {1.0f, 2.0f}; + vector W_shape = {2, 1, 1, 1}; + vector B = {1.0f, -1.0f}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 3, 3}; + auto expected_vals = { + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, + + 17.0f, 19.0f, 21.0f, + 23.0f, 25.0f, 27.0f, + 29.0f, 31.0f, 33.0f}; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Depthwise2D_Bias_Group15) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 15, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + // C = 0 + 0.0f, 1.0f, + 2.0f, 3.0f, + + // C = 1 + 4.0f, 5.0f, + 6.0f, 7.0f, + + // C = 2 + 8.0f, 9.0f, + 10.0f, 11.0f, + + // C = 3 + 12.0f, 13.0f, + 14.0f, 15.0f, + + // C = 4 + 16.0f, 17.0f, + 18.0f, 19.0f, + + // C = 5 + 20.0f, 21.0f, + 22.0f, 23.0f, + + // C = 6 + 24.0f, 25.0f, + 26.0f, 27.0f, + + // C = 7 + 28.0f, 29.0f, + 30.0f, 31.0f, + + // C = 8 + 32.0f, 33.0f, + 34.0f, 35.0f, + + // C = 9 + 36.0f, 37.0f, + 38.0f, 39.0f, + + // C = 10 + 40.0f, 41.0f, + 42.0f, 43.0f, + + // C = 11 + 44.0f, 45.0f, + 46.0f, 47.0f, + + // C = 12 + 48.0f, 49.0f, + 50.0f, 51.0f, + + // C = 13 + 52.0f, 53.0f, + 54.0f, 55.0f, + + // C = 14 + 56.0f, 57.0f, + 58.0f, 59.0f}; + vector X_shape = {1, 15, 2, 2}; + vector W = { + // M = 0 
+ 0.0f, 1.0f, + 2.0f, 3.0f, + + // M = 1 + 4.0f, 5.0f, + 6.0f, 7.0f, + + // M = 2 + 8.0f, 9.0f, + 10.0f, 11.0f, + + // M = 3 + 12.0f, 13.0f, + 14.0f, 15.0f, + + // M = 4 + 16.0f, 17.0f, + 18.0f, 19.0f, + + // M = 5 + 20.0f, 21.0f, + 22.0f, 23.0f, + + // M = 6 + 24.0f, 25.0f, + 26.0f, 27.0f, + + // M = 7 + 28.0f, 29.0f, + 30.0f, 31.0f, + + // M = 8 + 32.0f, 33.0f, + 34.0f, 35.0f, + + // M = 9 + 36.0f, 37.0f, + 38.0f, 39.0f, + + // M = 10 + 40.0f, 41.0f, + 42.0f, 43.0f, + + // M = 11 + 44.0f, 45.0f, + 46.0f, 47.0f, + + // M = 12 + 48.0f, 49.0f, + 50.0f, 51.0f, + + // M = 13 + 52.0f, 53.0f, + 54.0f, 55.0f, + + // M = 14 + 56.0f, 57.0f, + 58.0f, 59.0f}; + vector W_shape = {15, 1, 2, 2}; + vector B = { + 101.0f, + 102.0f, + 103.0f, + 104.0f, + 105.0f, + 106.0f, + 107.0f, + 108.0f, + 109.0f, + 110.0f, + 111.0f, + 112.0f, + 113.0f, + 114.0f, + 115.0f}; + vector B_shape = {15}; + vector Y_shape = {1, 15, 1, 1}; + auto expected_vals = { + 115.0f, // 0.0*0.0 + 1.0*1.0 + 2.0*2.0 + 3.0*3.0 + 101.0 + 228.0f, + 469.0f, + 838.0f, + 1335.0f, + 1960.0f, + 2713.0f, // 24.0*24.0 + 25.0*25.0 + 26.0*26.0 + 27.0*27.0 + 107.0 + 3594.0f, + 4603.0f, + 5740.0f, + 7005.0f, + 8398.0f, + 9919.0f, // 48.0*48.0 + 49.0*49.0 + 50.0*50.0 + 51.0*51.0 + 113.0 + 11568.0f, // 52.0*52.0 + 53.0*53.0 + 54.0*54.0 + 55.0*55.0 + 114.0 + 13345.0f // 56.0*56.0 + 57.0*57.0 + 58.0*58.0 + 59.0*59.0 + 115.0 + }; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + TEST(ConvTest, ConvDimWithZero) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -667,7 +915,8 @@ TEST(ConvTest, ConvDimWithZero) { // not handled by ACL attrs.excluded_providers.insert(kAclExecutionProvider); - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, OpTester::ExpectResult::kExpectSuccess, "", 10); + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, optional(), + OpTester::ExpectResult::kExpectSuccess, "", 10); } TEST(ConvTest, Conv1D_asymmetric_padding) { diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 98a65b8efffd..ab24337046b9 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -263,6 +263,22 @@ TEST(ReductionOpTest, ReduceL1) { test.Run(); } +TEST(ReductionOpTest, ReduceL1_double) { + OpTester test("ReduceL1"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {1, 2, 1}, {33.0f, 45.0f}); + test.Run(); +} + TEST(ReductionOpTest, ReduceL1_int32) { OpTester test("ReduceL1"); test.AddAttribute("axes", std::vector{0, 2}); @@ -423,6 +439,23 @@ TEST(ReductionOpTest, ReduceL2) { test.Run(); } +TEST(ReductionOpTest, ReduceL2_double) { + OpTester test("ReduceL2"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {2}, {15.71623325f, 20.07485962f}); + test.Run(); +} + #if defined(USE_DNNL) TEST(ReductionOpTest, ReduceL2_bfloat16) { #ifdef USE_DNNL @@ -512,6 +545,25 @@ TEST(ReductionOpTest, ReduceLogSum) { test.Run(); } +TEST(ReductionOpTest, ReduceLogSum_double) { + 
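+  // Note: with axes = {1} and keepdims = 1, each expected value is the natural log of a
+  // two-element sum along the middle dimension, e.g. log(1 + 3) = 1.38629436,
+  // log(2 + 4) = 1.79175949, log(9 + 11) = 2.99573231 and log(10 + 12) = 3.09104252.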
OpTester test("ReduceLogSum"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {3, 1, 2}, + {1.38629436f, 1.79175949f, + 2.48490667f, 2.6390574f, + 2.99573231f, 3.09104252f}); + test.Run(); +} + #if defined(USE_DNNL) TEST(ReductionOpTest, ReduceLogSum_bfloat16) { #ifdef USE_DNNL @@ -1820,6 +1872,23 @@ TEST(ReductionOpTest, ReduceMean_int32) { test.Run(); } +TEST(ReductionOpTest, ReduceMean_int64) { + OpTester test("ReduceMean"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {10, 20, + 30, 40, + + 50, 60, + 70, 80, + + 90, 100, + 110, 120}); + test.AddOutput("reduced", {1, 2, 1}, {55, 75}); + test.Run(); +} + TEST(ReductionOpTest, ReduceMean_axes_input) { OpTester test("ReduceMean", 18, onnxruntime::kOnnxDomain); test.AddAttribute("keepdims", (int64_t)1); @@ -2989,6 +3058,22 @@ TEST(ReductionOpTest, ReduceProd) { test.Run(); } +TEST(ReductionOpTest, ReduceProd_double) { + OpTester test("ReduceProd"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddOutput("reduced", {1, 2, 1}, {5400.f, 88704.f}); + test.Run(); +} + TEST(ReductionOpTest, ReduceProdAxesInitializerOpset18) { OpTester test("ReduceProd", 18); test.AddInput("data", {3, 2, 2}, diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc index b05649dafc18..30960e71c577 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc @@ -98,8 +98,12 @@ static void RunGruTest(const std::vector& X_data, test.AddOptionalOutputEdge(); } - // TensorRT failed on GRU tests +// TensorRT, OpenVINO failed on GRU tests +#if defined(USE_OPENVINO) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +#else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +#endif } void DefaultActivationsSimpleWeightsNoBias(std::string direction, diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index 5222380d9ca5..a0c1d675f506 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -373,5 +373,36 @@ TEST(TensorOpTest, DepthToSpaceTest_5) { test.Run(); } +TEST(TensorOpTest, DepthToSpaceTest_CRD_Batched) { + OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "CRD" mode + constexpr int64_t blocksize = 2; + test.AddAttribute("blocksize", blocksize); + test.AddAttribute("mode", "CRD"); + + constexpr int64_t N = 2, C = 4, H = 2, W = 3; + std::vector X = {0., 1., 2., + 3., 4., 5., + 9., 10., 11., + 12., 13., 14., + 18., 19., 20., + 21., 22., 23., + 27., 28., 29., + 30., 31., 32.}; + + // append same data but in reverse order so we can tell if the batch output is wrong + X.insert(X.end(), X.rbegin(), X.rend()); + + test.AddInput("input", {N, C, H, W}, X); + + std::vector result = {0., 9., 1., 10., 2., 11., + 18., 27., 19., 28., 20., 29., + 3., 12., 4., 13., 5., 14., + 21., 30., 22., 31., 23., 32.}; + 
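+  // Note: in "CRD" mode with blocksize 2 the mapping is
+  // output[n][c'][2h + i][2w + j] = input[n][c'*4 + 2i + j][h][w], so with C = 4 (C' = 1) the
+  // first output row interleaves channels 0 and 1 (0, 9, 1, 10, 2, 11) and the second row
+  // interleaves channels 2 and 3 (18, 27, 19, 28, 20, 29).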
result.insert(result.end(), result.rbegin(), result.rend()); + + test.AddOutput("output", {2, 1, 4, 6}, result); + test.Run(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc index 036c5760ed56..0a39413a4ec1 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -80,8 +80,7 @@ template static GetTestModelFn BuildBatchNormTestCase(const TestInputDef& input_def, const TestInputDef& scale_def, const TestInputDef& bias_def) { - ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs. - ORT_ENFORCE(input_def.GetShape().size() > 2); // Need at least rank 3 data for convenience. + ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs. return [input_def, scale_def, bias_def](ModelTestBuilder& builder) { const auto& input_shape = input_def.GetShape(); @@ -103,45 +102,39 @@ static GetTestModelFn BuildBatchNormTestCase(const TestInputDef& inp }; } -template +template GetTestQDQModelFn BuildQDQBatchNormTestCase(const TestInputDef& input_def, const TestInputDef& scale_def, const TestInputDef& bias_def) { - ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs. - ORT_ENFORCE(input_def.GetShape().size() > 2); // Need at least rank 3 data for convenience. + ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs. return [input_def, scale_def, bias_def](ModelTestBuilder& builder, std::vector>& output_qparams) { const auto& input_shape = input_def.GetShape(); const auto& input_data = input_def.GetRawData(); const int64_t num_channels = input_shape[1]; - + bool symmetric = sizeof(InputQType) == sizeof(uint16_t); NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def, symmetric); NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); NodeArg* scale = MakeTestInput(builder, scale_def); QuantParams scale_qparams = GetTestInputQuantParams(scale_def); NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point); - NodeArg* bias = MakeTestInput(builder, bias_def); - QuantParams bias_qparams = GetTestInputQuantParams(bias_def); - NodeArg* bias_qdq = AddQDQNodePair(builder, bias, bias_qparams.scale, bias_qparams.zero_point); + NodeArg* bias_qdq; + // bias (as int32) => DQ => + bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale, true); std::vector mean_vals(num_channels); std::vector var_vals(num_channels); ComputeChannelMeanAndVar(input_data, input_shape, mean_vals, var_vals); NodeArg* mean = builder.MakeInitializer({num_channels}, mean_vals); - QuantParams mean_qparams = GetDataQuantParams(mean_vals); - NodeArg* mean_qdq = AddQDQNodePair(builder, mean, mean_qparams.scale, mean_qparams.zero_point); - NodeArg* var = builder.MakeInitializer({num_channels}, var_vals); - QuantParams var_qparams = GetDataQuantParams(var_vals); - NodeArg* var_qdq = AddQDQNodePair(builder, var, var_qparams.scale, var_qparams.zero_point); auto* batchnorm_output = builder.MakeIntermediate(); - builder.AddNode("BatchNormalization", {input_qdq, scale_qdq, bias_qdq, mean_qdq, var_qdq}, + builder.AddNode("BatchNormalization", {input_qdq, scale_qdq, bias_qdq, mean, 
var}, {batchnorm_output}); AddQDQNodePairWithOutputAsGraphOutput(builder, batchnorm_output, output_qparams[0].scale, output_qparams[0].zero_point); @@ -155,6 +148,7 @@ GetTestQDQModelFn BuildQDQBatchNormTestCase(const TestInputDef static void RunBatchNormQDQTest(const TestInputDef& input_def, const TestInputDef& scale_def, const TestInputDef& bias_def, @@ -169,9 +163,9 @@ static void RunBatchNormQDQTest(const TestInputDef& input_def, // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. TestQDQModelAccuracy(BuildBatchNormTestCase(input_def, scale_def, bias_def), - BuildQDQBatchNormTestCase(input_def, scale_def, bias_def), + BuildQDQBatchNormTestCase(input_def, scale_def, bias_def), provider_options, - 11, + 21, expected_ep_assignment, tolerance); } @@ -199,31 +193,69 @@ static void RunBatchNormFP16Test(const TestInputDef& input_def, expected_ep_assignment); } +// BatchNor QDQ model, input with rank 2. +TEST_F(QnnHTPBackendTests, BatchNormRank2) { + constexpr int64_t num_channels = 2; + + RunBatchNormQDQTest(TestInputDef({4, num_channels}, false, + {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f}), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::All); +} + // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 3. +// Accuracy issue with Linux simulator, not sure with Android device +// Inaccuracy detected for output 'output_0', element 1 +// output_range=4.8666362762451172, tolerance=0.40000000596046448%. +// Expected val (f32@CPU_EP): 1.0999999046325684 +// qdq@QNN_EP val: -0.17176364362239838 (err: 1.2717635631561279, err/output_range: 26.132291793823242%) +// qdq@CPU_EP val: 1.1069211959838867 (err: 0.0069212913513183594, err/output_range: 0.14221921563148499%) +// abs(qdq@QNN_EP - qdq@CPU_EP) / output_range = 25.990072250366211% +// +// Inaccuracy detected for output 'output_0', element 2 +// output_range=4.8666362762451172, tolerance=0.40000000596046448%. +// Expected val (f32@CPU_EP): 2.3247356414794922 +// qdq@QNN_EP val: -0.17176364362239838 (err: 2.4964993000030518, err/output_range: 51.298248291015625%) +// qdq@CPU_EP val: 2.3474364280700684 (err: 0.022700786590576172, err/output_range: 0.46645742654800415%) +#if defined(_WIN32) TEST_F(QnnHTPBackendTests, BatchNorm1D) { constexpr int64_t num_channels = 2; - RunBatchNormQDQTest(TestInputDef({1, num_channels, 3}, false, {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data - TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer - TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer - ExpectedEPNodeAssignment::All); + RunBatchNormQDQTest(TestInputDef({1, num_channels, 3}, false, + {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::All); +} +#endif + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 4. 
+TEST_F(QnnHTPBackendTests, BatchNorm2D_a8w8) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::All); } // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. -TEST_F(QnnHTPBackendTests, BatchNorm2D) { +TEST_F(QnnHTPBackendTests, BatchNorm2D_a16w8) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; - RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data - TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer - TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer - ExpectedEPNodeAssignment::All, - // Require a slightly increased tolerance on Windows ARM64 (from 0.4% to 0.6%). - QDQTolerance(0.006f)); + RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::All); } // Test FP16 BatchNormalization on the HTP backend. @@ -272,10 +304,11 @@ TEST_F(QnnHTPBackendTests, BatchNorm_FP32_as_FP16) { TEST_F(QnnHTPBackendTests, BatchNorm3D) { constexpr int64_t num_channels = 2; constexpr int64_t num_elems = 1 * num_channels * 3 * 4 * 5; - RunBatchNormQDQTest(TestInputDef({1, num_channels, 3, 4, 5}, false, std::vector(num_elems)), // Input data (all zeros) - TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer - TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer - ExpectedEPNodeAssignment::None); + RunBatchNormQDQTest(TestInputDef({1, num_channels, 3, 4, 5}, false, + std::vector(num_elems)), // Input data (all zeros) + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::None); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 99636976b9c0..95673586677e 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace test { +// Information for activation node placed between the Conv and Q. +struct OutputActivationInfo { + std::string op_type; // Relu or Clip + std::vector const_inputs; +}; + // Creates a graph with a single float32 Conv operator. Used for testing CPU backend. 
static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, const TestInputDef& input_def, const TestInputDef& weights_def, @@ -23,9 +29,10 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons const std::vector& pads, const std::vector& dilations, std::optional group, - const std::string& auto_pad = "NOTSET") { + const std::string& auto_pad = "NOTSET", + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, group, auto_pad](ModelTestBuilder& builder) { + dilations, group, auto_pad, output_activation](ModelTestBuilder& builder) { std::vector conv_inputs = { MakeTestInput(builder, input_def), MakeTestInput(builder, weights_def)}; @@ -34,9 +41,9 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons conv_inputs.push_back(MakeTestInput(builder, bias_def)); } - auto* output = builder.MakeOutput(); + auto* conv_output = output_activation.has_value() ? builder.MakeIntermediate() : builder.MakeOutput(); - Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {output}); + Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {conv_output}); conv_node.AddAttribute("auto_pad", auto_pad); if (group.has_value()) { @@ -54,6 +61,15 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons if (!dilations.empty()) { conv_node.AddAttribute("dilations", dilations); } + + if (output_activation.has_value()) { + NodeArg* output = builder.MakeOutput(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {output}); + } }; } @@ -88,19 +104,22 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef // Creates a graph with a single Q/DQ Conv operator. Used for testing HTP backend. 
template -static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, - const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad = "NOTSET", - bool use_contrib_qdq = false) { +static GetTestQDQModelFn BuildQDQConvTestCase( + const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false, + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, group, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { + dilations, group, auto_pad, + use_contrib_qdq, output_activation](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -144,27 +163,39 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + NodeArg* q_input = conv_output; + if (output_activation.has_value()) { + q_input = builder.MakeIntermediate(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); }; } template -static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const std::string& conv_op_type, - const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - int64_t weight_quant_axis, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad = "NOTSET", - bool use_contrib_qdq = false) { +static GetTestQDQModelFn BuildQDQPerChannelConvTestCase( + const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + int64_t weight_quant_axis, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false, + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, group, auto_pad, use_contrib_qdq, - weight_quant_axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { + weight_quant_axis, output_activation](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -248,7 +279,17 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + NodeArg* q_input = conv_output; + if (output_activation.has_value()) { + q_input = builder.MakeIntermediate(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + 
activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -267,7 +308,8 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, int opset = 13, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + std::optional output_activation = std::nullopt) { ProviderOptions provider_options; #if defined(_WIN32) @@ -277,10 +319,11 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef #endif TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad), + group, auto_pad, output_activation), BuildQDQConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad, use_contrib_qdq), + group, auto_pad, use_contrib_qdq, + output_activation), provider_options, opset, expected_ep_assignment, @@ -302,7 +345,8 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, int opset = 13, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + std::optional output_activation = std::nullopt) { ProviderOptions provider_options; #if defined(_WIN32) @@ -312,11 +356,11 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te #endif auto f32_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad); + group, auto_pad, output_activation); auto qdq_fn = BuildQDQPerChannelConvTestCase(conv_op_type, input_def, weights_def, bias_def, weight_quant_axis, strides, pads, dilations, group, auto_pad, - use_contrib_qdq); + use_contrib_qdq, output_activation); TestQDQModelAccuracy(f32_fn, qdq_fn, provider_options, opset, expected_ep_assignment, tolerance); } @@ -764,6 +808,140 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) { 21); // opset } +// Test fusion of DQs -> Conv -> Relu/Clip -> Q. +// User per-tensor quantization. 
+TEST_F(QnnHTPBackendTests, ConvU8U8S32_ReluClipFusion) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + // DQs -> Conv (w/ bias) -> Relu -> Q + OutputActivationInfo relu_info = {"Relu", {}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (NO bias) -> Relu -> Q + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (w/ bias) -> Clip -> Q + // Opset 6 Clip uses attributes for min/max + OutputActivationInfo clip_info = {"Clip", {0.0f, 2.0f}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 19, // opset + QDQTolerance(), + clip_info); + + // DQs -> Conv (NO bias) -> Clip -> Q + OutputActivationInfo clip_info_2 = {"Clip", {-6.0f, 6.0f}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + clip_info_2); +} + +// Test fusion of DQs -> Conv -> Relu/Clip -> Q. +// User per-channel quantization. 
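The per-channel variant below uses weight_quant_axis 0, i.e. one scale/zero-point per output channel of the Conv weight. A short numpy sketch of that quantization step, under an assumed symmetric int8 scheme (the tests' own helpers such as GetTestInputQuantParamsPerChannel are not reproduced here):

# Illustrative per-channel (axis 0) symmetric int8 weight quantization.
import numpy as np

rng = np.random.default_rng(0)
weights = rng.uniform(-1.0, 5.0, size=(3, 2, 2, 2)).astype(np.float32)  # OIHW layout

abs_max = np.max(np.abs(weights), axis=(1, 2, 3))      # one range per output channel
scales = np.maximum(abs_max, 1e-4) / 127.0             # symmetric: zero_point == 0
scales_bc = scales.reshape(-1, 1, 1, 1)                # broadcast along the quant axis
q_weights = np.clip(np.rint(weights / scales_bc), -127, 127).astype(np.int8)

# Dequantizing recovers the originals up to a small per-channel error.
dq_weights = q_weights.astype(np.float32) * scales_bc
print(np.max(np.abs(dq_weights - weights)))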
+TEST_F(QnnHTPBackendTests, ConvS8S8S32_PerChannel_ReluClipFusion) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + // DQs -> Conv (w/ bias) -> Relu -> Q + OutputActivationInfo relu_info = {"Relu", {}}; + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (w/ bias) -> Clip -> Q + OutputActivationInfo clip_info = {"Clip", {0.0f, 6.0f}}; + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + clip_info); +} + // Test per-channel QDQ Conv with INT4 weights and a negative weight quantization axis that still points to dimension 0. TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) { std::vector input_shape = {1, 2, 4, 4}; @@ -799,7 +977,7 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) { // CPU EP (f32 model): 25.143 21.554 17.964 10.785 7.195 3.605 -3.574 -7.164 -10.753 // CPU EP (qdq model): 24.670 21.103 17.536 10.254 6.689 2.972 -4.161 -7.728 -10.700 // QNN EP (qdq model): 27.186 27.186 27.186 21.541 6.685 -8.022 -10.548 -10.548 -10.548 -TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { +TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_AccuracyIssue) { std::vector input_shape = {1, 2, 4, 4}; std::vector weight_shape = {3, 2, 2, 2}; std::vector bias_shape = {3}; @@ -835,7 +1013,8 @@ TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 21); // opset + 21, // opset + QDQTolerance(0.005f)); } // Test per-channel QDQ Conv is rejected with weight axis != 0 diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index dba60b104169..d8c34d6a6c6e 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -28,26 +28,25 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de // Returns a function that creates a graph with a QDQ MatMul operator. 
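The MatMul additions below follow the same pattern: build a float model and a QDQ model, then compare their outputs within a tolerance. A rough numpy sketch of that comparison using per-tensor uint8 fake-quantization (the quant-param math here is a simplified stand-in, not the test utilities):

# Illustrative fake-quant MatMul vs. float MatMul comparison (per-tensor uint8).
import numpy as np

def fake_quant_u8(x):
    rmin, rmax = min(float(x.min()), 0.0), max(float(x.max()), 0.0)  # range must include zero
    scale = (rmax - rmin) / 255.0
    zero_point = np.rint(-rmin / scale)
    q = np.clip(np.rint(x / scale) + zero_point, 0, 255)
    return (q - zero_point) * scale                                  # dequantize back to float

a = np.array([[-10.0, -4.0, -2.0], [0.0, 5.0, 10.0]], dtype=np.float32)
b = np.array([[-10.0, -6.0], [-1.0, 0.0], [3.0, 10.0]], dtype=np.float32)

f32_out = a @ b
qdq_out = fake_quant_u8(fake_quant_u8(a) @ fake_quant_u8(b))
print(np.max(np.abs(qdq_out - f32_out)))   # error that a QDQTolerance-style bound would check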
template -static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input1_def, - const TestInputDef& input2_def, +static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def, bool use_contrib_qdq) { - return [input1_def, input2_def, use_contrib_qdq](ModelTestBuilder& builder, + return [input0_def, input1_def, use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { // input1 -> Q -> DQ -> - NodeArg* input1 = MakeTestInput(builder, input1_def); - QuantParams input1_qparams = GetTestInputQuantParams(input1_def); - auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_qparams = GetTestInputQuantParams(input0_def); + auto* input0_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, input0_qparams.zero_point, use_contrib_qdq); - - // input2 -> Q -> DQ -> - NodeArg* input2 = MakeTestInput(builder, input2_def); - QuantParams input2_qparams = GetTestInputQuantParams(input2_def); - auto* input2_qdq = AddQDQNodePair(builder, input2, input2_qparams.scale, input2_qparams.zero_point, + // input1 -> Q -> DQ -> + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, use_contrib_qdq); // MatMul auto* op_output = builder.MakeIntermediate(); - builder.AddNode("MatMul", {input1_qdq, input2_qdq}, {op_output}); + builder.AddNode("MatMul", {input0_qdq, input1_qdq}, {op_output}); // op_output -> Q -> DQ -> output AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, @@ -55,6 +54,88 @@ static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDe }; } +template +static GetTestQDQModelFn BuildQDQPerChannelMatMulTestCase(const TestInputDef& input_def, + const TestInputDef& weights_def, + int64_t weight_quant_axis, + bool use_contrib_qdq = false) { + return [input_def, weights_def, weight_quant_axis, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector matmul_inputs; + + // input -> Q/DQ -> + auto* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + matmul_inputs.push_back(input_qdq); + + // Quantized(weights) -> DQ -> + ORT_ENFORCE(weights_def.IsInitializer() && weights_def.IsRawData()); + std::vector weight_scales; + std::vector weight_zero_points; + TensorShape weights_shape = weights_def.GetTensorShape(); + int64_t pos_weight_quant_axis = weight_quant_axis; + if (pos_weight_quant_axis < 0) { + pos_weight_quant_axis += static_cast(weights_shape.NumDimensions()); + } + GetTestInputQuantParamsPerChannel(weights_def, weight_scales, weight_zero_points, + static_cast(pos_weight_quant_axis), true); + + std::vector quantized_weights; + size_t num_weight_storage_elems = weights_shape.Size(); + if constexpr (std::is_same_v || std::is_same_v) { + num_weight_storage_elems = Int4x2::CalcNumInt4Pairs(weights_shape.Size()); + } + quantized_weights.resize(num_weight_storage_elems); + QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, + weight_scales, weight_zero_points, pos_weight_quant_axis); + + NodeArg* weights_initializer = builder.MakeInitializer(weights_def.GetShape(), 
quantized_weights); + NodeArg* weights_dq = builder.MakeIntermediate(); + Node& weights_dq_node = builder.AddDequantizeLinearNode(weights_initializer, weight_scales, + weight_zero_points, weights_dq, + nullptr, use_contrib_qdq); + weights_dq_node.AddAttribute("axis", weight_quant_axis); + matmul_inputs.push_back(weights_dq); + + auto* matmul_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", matmul_inputs, {matmul_output}); + + AddQDQNodePairWithOutputAsGraphOutput(builder, matmul_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ per-channel MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the +// QDQ model is accurate on QNN EP (compared to CPU EP). +template +static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, + const TestInputDef& weights_def, + int64_t weight_quant_axis, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 21, + bool use_contrib_qdq = false, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), + BuildQDQPerChannelMatMulTestCase(input_def, + weights_def, + weight_quant_axis, + use_contrib_qdq), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + // Runs an MatMul model on the QNN CPU backend. Checks the graph node assignment, and that inference // outputs for QNN and CPU match. static void RunMatMulOpOpTest(const TestInputDef& input1_def, @@ -160,6 +241,55 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { true); // Use com.microsoft Q/DQ ops } +// Test QDQ per-channel MatMul with 16-bit act, signed 4-bit weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt4) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + +// Test QDQ per-channel MatMul with 16-bit act, unsigned 4-bit weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + +// Test QDQ per-channel MatMul with int8 act, int4 weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { + std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); + std::vector input1_data = {-2.0f, -1.0f, -0.5f, 0.0f, 1.0f, 2.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false, + QDQTolerance(0.007f)); +} + +// Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) +TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt8) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 
0.0f, 3.0f, 10.0f}; + RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), + TestInputDef({1, 1, 3, 2}, true, input1_data), + 1, // quantization axis + ExpectedEPNodeAssignment::All, + 21, + false); +} + // Test QDQ MatMul with uint16 activation uint16 weights, both dynamic // Inaccuracy detected for output 'output_0', element 1. // Output quant params: scale=0.0015259021893143654, zero_point=0. diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a3768cb98f58..be3bd2cc5dcd 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -279,6 +279,45 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + std::string node_name_prefix = "node_name_prefix_test"; + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + for (auto& node : model->MainGraph().Nodes()) { + if (node.OpType() == "EPContext") { + EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); + } + } + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + // Run QDQ model on HTP 3 times // 1st run will generate the Qnn context cache onnx file // 2nd run directly loads and run from Qnn context cache model diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index eb03270dc846..3a6753e9b613 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -42,7 +42,7 @@ struct QuantParams { symmetric); } - static QuantParams Compute(float rmin, float rmax, QType qmin, QType qmax, bool symmetric = false) { + static QuantParams Compute(float rmin, float rmax, float qmin, float qmax, bool symmetric = false) { // Ensure a minimum range of 0.0001 (required by QNN) rmax = std::max(rmax, rmin + 0.0001f); @@ -56,8 +56,8 @@ struct QuantParams { rmin = -abs_max; } - float qmin_flt = static_cast(qmin); - float qmax_flt = static_cast(qmax); + float qmin_flt = qmin; + float qmax_flt = qmax; const float scale = (rmax - rmin) / (qmax_flt - qmin_flt); float initial_zero_point = 0.0f; diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py 
b/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py index b5400b487cfc..c245699e211d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py +++ b/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -*- coding: UTF-8 -*- import unittest import numpy as np @@ -10,7 +9,7 @@ import onnxruntime.backend as backend from onnxruntime import datasets -from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend # noqa: N813 +from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend def check_list_of_map_to_float(testcase, expected_rows, actual_rows): diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd new file mode 100644 index 000000000000..0a6d0c37b4a3 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -0,0 +1,47 @@ +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" + +set CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +set MKL_DYNAMIC=FALSE +set OMP_NUM_THREADS=1 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=2 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=4 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=8 +python benchmark_mha.py --torch + +set MKL_DYNAMIC= +set OMP_NUM_THREADS= + +set ORT_DISABLE_FLASH_ATTENTION=0 +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py + +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +python benchmark_mha.py --causal --has_past + +set ORT_DISABLE_FLASH_ATTENTION=1 +python benchmark_mha.py +set ORT_DISABLE_FLASH_ATTENTION= + +type benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 111c417479d2..ec350874af32 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -4,21 +4,35 @@ # -------------------------------------------------------------------------- """ -Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: -sh benchmark_mha.sh +Benchmark performance of MultiHeadAttention with ORT or PyTorch. 
+ +In Linux, run the the following: + sh benchmark_mha.sh + +In Windows, run the the following: + benchmark_mha.cmd """ +import argparse +import csv import math import os import platform import statistics import time -from typing import List, Optional +from contextlib import nullcontext +from datetime import datetime +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple import torch +import torch.utils.benchmark as benchmark from onnx import TensorProto, helper +from packaging.version import Version +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.functional import scaled_dot_product_attention -from onnxruntime import InferenceSession, get_available_providers +from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession @@ -43,6 +57,20 @@ def get_name_list() -> List[str]: return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"] +class SdpaKernel(IntEnum): + """Bit flags for sdpa_kernel CUDA provider option""" + + DEFAULT = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + TRT_FUSED_ATTENTION = 4 + CUDNN_FLASH_ATTENTION = 8 + MATH = 16 + TRT_FLASH_ATTENTION = 32 + TRT_CROSS_ATTENTION = 64 + TRT_CAUSAL_ATTENTION = 128 + + class MultiHeadAttentionConfig: def __init__( self, @@ -60,8 +88,11 @@ def __init__( enable_cuda_graph: bool = False, dtype=torch.float, use_kv_cache: bool = False, + has_past_input: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, + verbose: bool = False, + has_bias: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -74,15 +105,25 @@ def __init__( self.causal = causal self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + # Support the case that there is no past but need present output (for prompt case). + self.has_past_input = has_past_input + if has_past_input: + assert use_kv_cache + else: # no past input + assert past_sequence_length == 0 + + self.has_present_output = use_kv_cache + self.use_kv_cache = use_kv_cache if not use_kv_cache: assert past_sequence_length == 0 else: assert self.kv_sequence_length == self.sequence_length - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - assert not use_kv_cache + # Only BSNH input format supports past state. 
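The reworked shape_dict separates has_past_input (past_key/past_value are fed) from has_present_output (present_key/present_value are produced), both in BNSH layout, so the prompt case (no past, present output only) becomes expressible. A small torch sketch of the intended cache shapes (sizes are illustrative):

# Illustrative KV-cache shapes: prompt step produces present, token step consumes past.
import torch

batch, num_heads, head_size = 2, 4, 64

# Prompt step: no past input; present holds the whole prompt.
prompt_len = 16
present_k = torch.zeros(batch, num_heads, prompt_len, head_size)   # BNSH

# Token step: one new token appended to the accumulated past.
past_k = present_k                                                  # (B, N, past_len, H)
token_k = torch.zeros(batch, num_heads, 1, head_size)
present_k = torch.cat((past_k, token_k), dim=2)                     # (B, N, past_len + 1, H)

assert present_k.shape == (batch, num_heads, prompt_len + 1, head_size)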
+ if input_format != InputFormats.Q_K_V_BSNH_BSNH_BSNH: + assert not self.has_past_input + assert not self.has_present_output # Derived values self.total_sequence_length = self.kv_sequence_length + past_sequence_length @@ -100,6 +141,8 @@ def __init__( self.input_format = input_format self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H + self.verbose = verbose + self.has_bias = has_bias def __repr__(self): return ( @@ -110,96 +153,115 @@ def __repr__(self): f"causal={self.causal}), softmax_scale={self.softmax_scale}, use_kv_cache={self.use_kv_cache}, " f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " - f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}" + f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " + f"has_bias={self.has_bias}" ) def shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), } - if self.use_kv_cache: + if self.has_past_input: shapes = { + **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), - "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } - else: + + if self.has_present_output: shapes = { - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + **shapes, + "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), + "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size)}) - elif input_format 
== InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - ) + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes def symbolic_shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), "key": ("batch_size", self.num_heads, "sequence_length", self.head_size), "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), } - if self.use_kv_cache: + if self.has_past_input: shapes = { + **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), - "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } - else: + + if self.has_present_output: shapes = { - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + **shapes, + "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), + "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads * 
self.head_size), - "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - ) + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes - def random_inputs(self, seed: int = 123): + def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): device = self.device dtype = self.dtype @@ -212,47 +274,56 @@ def random_inputs(self, seed: int = 123): q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + + bias_q = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_k = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_v = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + if no_bias_k_v: + bias_k = torch.zeros_like(bias_k) + bias_v = torch.zeros_like(bias_v) + k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) - if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - return { + if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + feeds = { "query": q.reshape(shape_dict["query"]), - "key": k_bnsh.contiguous(), - "value": v_bnsh.contiguous(), + "key": k.reshape(shape_dict["key"]), + "value": v.reshape(shape_dict["value"]), } - - feeds = {} - if self.use_kv_cache: - feeds.update( - { - "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - } - ) - - if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: - feeds.update( - { - "query": q.reshape(shape_dict["query"]), - "key": k.reshape(shape_dict["key"]), - "value": v.reshape(shape_dict["value"]), - } - ) elif self.input_format == InputFormats.QKV_BSN3H: query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous() + feeds = { + "query": torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous(), + } elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = q.reshape(shape_dict["query"]) - feeds["key"] = torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous() + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous(), + } + else: + assert self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": k_bnsh.contiguous(), + "value": v_bnsh.contiguous(), + } + + if self.has_past_input: + feeds = { + **feeds, + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( + mean=0, std=0.1 + ), + } + + if self.has_bias: + feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], 
dim=0).reshape(shape_dict["bias"]).contiguous() return feeds @@ -267,15 +338,29 @@ def get_input_output_names(self): else: inputs, outputs = ["query", "key", "value"], ["output"] - if self.use_kv_cache: - return [*inputs, "past_key", "past_value"], [*outputs, "present_key", "present_value"] - else: - return inputs, outputs + if self.has_bias: + inputs = [*inputs, "bias"] + + if self.has_past_input: + inputs = [*inputs, "past_key", "past_value"] + + if self.has_present_output: + outputs = [*outputs, "present_key", "present_value"] + + return inputs, outputs def fill_optional_mha_inputs(input_names): inputs = ["query", "key", "value", "bias", "key_padding_mask", "relative_position_bias", "past_key", "past_value"] - return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:] + + # Remove optional inputs that are not in input_names with empty string + inputs_with_optional = [input if input in input_names else "" for input in inputs] + + # Remove empty string at the end of the list. + while inputs_with_optional[-1] == "": + inputs_with_optional.pop(-1) + + return inputs_with_optional def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False): @@ -285,7 +370,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use nodes = [ helper.make_node( "MultiHeadAttention", - fill_optional_mha_inputs(input_names) if config.use_kv_cache else input_names, + fill_optional_mha_inputs(input_names), output_names, "MultiHeadAttention_0", num_heads=config.num_heads, @@ -299,11 +384,13 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use inputs = [ helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) for input_name in input_names + if input_name ] outputs = [ helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name])) for output_name in output_names + if output_name ] graph = helper.make_graph( @@ -318,19 +405,36 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use return model.SerializeToString() -def create_session( +def create_ort_session( config: MultiHeadAttentionConfig, + session_options=None, + attention_kernel=SdpaKernel.DEFAULT, + use_symbolic_shape: bool = True, + use_tf32: bool = True, ) -> CudaSession: - onnx_model_str = create_multi_head_attention_onnx_model(config) + if config.verbose: + print(f"create session for {vars(config)}") + onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=use_symbolic_shape) if config.provider == "CUDAExecutionProvider": device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) + provider_options["sdpa_kernel"] = int(attention_kernel) + provider_options["use_tf32"] = int(use_tf32) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + return ort_session + + +def create_session( + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_tf32: bool = True +) -> CudaSession: + ort_session = create_ort_session( + config, session_options, attention_kernel, use_symbolic_shape=False, use_tf32=use_tf32 + ) cuda_session = CudaSession(ort_session, 
config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -340,11 +444,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__( - self, - config: MultiHeadAttentionConfig, - ): - self.ort_session = create_session(config) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_tf32: bool = True): + self.ort_session = create_session(config, session_options, use_tf32=use_tf32) self.feed_dict = config.random_inputs() def infer(self): @@ -363,53 +464,90 @@ def flops(batch, sequence_length, head_size, num_heads, causal): def tflops_per_second(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - - -def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str: - # This classification is for Nvidia GPU of Compute Capability 8.* like A100. - # Note that some kernel might not exist in older or newer GPUs. - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - if config.input_format == InputFormats.QKV_BSN3H: - min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") - min_length = int(min_seq_len) if min_seq_len is not None else 513 - if config.sequence_length >= min_length: - return "Flash" - else: - return "Flash" + try: + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + except ZeroDivisionError: + return None + + +def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: + kernel_names = { + SdpaKernel.DEFAULT: "ort:default", + SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", + SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", + SdpaKernel.MATH: "ort:math", + } + assert attention_kernel in kernel_names + return kernel_names[attention_kernel] - if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( - os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" - and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") - ): - return "TRT" - if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": - return "MemEff" +def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: + # CPU Flash Attention does not support causal and kv cache etc. 
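tflops_per_second now tolerates a zero latency instead of raising. For orientation, a commonly used FLOP estimate for scaled dot product attention counts the QK^T and PV matmuls and halves the total for causal masks; the sketch below states that assumed convention rather than restating the file's flops() helper:

# Assumed FLOP-counting convention for SDPA (QK^T plus PV, halved if causal).
def sdpa_flops(batch, seq_len, head_size, num_heads, causal):
    # Each of QK^T and P@V is roughly 2 * S * S * H FLOPs per head.
    total = 4.0 * batch * num_heads * seq_len * seq_len * head_size
    return total / 2.0 if causal else total

def tflops_per_sec(flop, latency_seconds):
    return flop / latency_seconds / 1e12 if latency_seconds > 0 else None

# Example: 16 heads, head size 64, 512 tokens, 2 ms latency -> ~0.54 TFLOPS.
print(tflops_per_sec(sdpa_flops(1, 512, 64, 16, causal=False), 2e-3))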
+ if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0): + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + return "ort:flash" - return "Unfused" + return "ort:math" -def get_cpu_kernel_name() -> str: - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - return "CPU:Flash" - return "CPU:Unfused" +# ------------------------------------------------------------------ +# Functions for benchmarking PyTorch SDPA +# ------------------------------------------------------------------ +def benchmark_torch_function(func: Callable, *args, **kwargs) -> float: + warmup = 5 + repeats = 100 + for _ in range(warmup): + func(*args, **kwargs) + timer = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + + return timer.timeit(number=repeats).median -def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100): - if use_gpu: - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] - provider = "CUDAExecutionProvider" - print(f"enable_cuda_graph={enable_cuda_graph}") - else: - device_id = 0 - device = torch.device("cpu") - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - enable_cuda_graph = False - provider = "CPUExecutionProvider" +def run_torch_sdpa( + batch_size: int, + q_seq_len: int, + kv_seq_len: int, + num_heads: int, + head_size: int, + causal: bool, + device, + dtype, + has_mask: bool = False, + mask_dim: int = 2, + mask_dtype=torch.bool, + backend: Optional[int] = None, +): + q_shape = (batch_size, num_heads, q_seq_len, head_size) + kv_shape = (batch_size, num_heads, kv_seq_len, head_size) + q = torch.randn(q_shape, device=device, dtype=dtype) + k = torch.randn(kv_shape, device=device, dtype=dtype) + v = torch.randn(kv_shape, device=device, dtype=dtype) + + attn_mask = None + if has_mask: + mask_shape = (batch_size, num_heads, q_seq_len, kv_seq_len) if mask_dim == 4 else (q_seq_len, kv_seq_len) + attn_mask = torch.ones(mask_shape, dtype=mask_dtype, device=device) + + context = sdpa_kernel(backend) if backend is not None else nullcontext() + + with context: + average_latency = benchmark_torch_function( + scaled_dot_product_attention, + q, + k, + v, + is_causal=causal, + attn_mask=attn_mask, + ) + return average_latency + + +def get_test_configs(use_gpu: bool = True): if use_gpu: # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused) configs = [ @@ -450,31 +588,70 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea ] else: configs = [ + # TNLGv4 (1, 128, 0, 32, 128, True), (1, 256, 0, 32, 128, True), (1, 512, 0, 32, 128, True), (1, 1024, 0, 32, 128, True), - (1, 2048, 0, 32, 128, True), + # (1, 2048, 0, 32, 128, True), + # bert-base + (1, 128, 0, 12, 64, True), + (1, 384, 0, 12, 64, True), + (1, 512, 0, 12, 64, True), + (4, 128, 0, 12, 64, True), + (4, 384, 0, 12, 64, True), + (4, 512, 0, 12, 64, True), + # bert-large + (1, 128, 0, 16, 64, True), + (1, 384, 0, 16, 64, True), + (1, 512, 0, 16, 64, True), + (4, 128, 0, 16, 64, True), + (4, 384, 0, 16, 64, True), + (4, 512, 0, 16, 64, True), ] + return configs - # List of environment variables to enable/disable attention kernels - print("Environment Variables:") - env_names = [ - "ORT_DISABLE_FLASH_ATTENTION", - "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", - "ORT_DISABLE_FUSED_ATTENTION", - "ORT_DISABLE_TRT_FLASH_ATTENTION", - 
"ORT_ENABLE_FUSED_CAUSAL_ATTENTION", - "ORT_DISABLE_FUSED_CROSS_ATTENTION", - "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", - ] - for name in env_names: - value = os.getenv(name) - if value is not None: - print(f"{name}={value}") - print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") - causal = False +def get_compute_capability(): + assert torch.cuda.is_available() + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm + + +def run_tflops_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + enable_cuda_graph: bool = False, + causal: bool = False, + has_past: bool = False, + intra_op_num_threads: int = 0, + repeats: int = 100, +): + print(f"run_tflops_test: causal={causal}") + + if use_gpu: + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] + provider = "CUDAExecutionProvider" + # flash attention is available for sm >= 80 + sm = get_compute_capability() + if sm >= 80: + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + else: + backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] + else: + device_id = 0 + device = torch.device("cpu") + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + enable_cuda_graph = False + provider = "CPUExecutionProvider" + backends = [SdpaKernel.DEFAULT] + + configs = get_test_configs(use_gpu) + + print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: @@ -496,21 +673,27 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea share_past_present_buffer=False, input_format=input_format, ) - - session = create_session(config) + for attention_kernel in backends: + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, attention_kernel=attention_kernel) if use_gpu: - kernel = get_gpu_kernel_name(config) + kernel = get_gpu_kernel_name(attention_kernel) else: - kernel = get_cpu_kernel_name() + kernel = get_cpu_kernel_name(config) - if kernel == "Unfused": + if "math" in kernel: # Skip large sequence length for Unfused kernel to avoid OOM. if not enable_unfused: + if config.verbose: + print(f"skip unfused kernel for {vars(config)}") continue # Unfused kernel does not support packed QKV or packed KV formats. 
if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + if config.verbose: + print(f"skip input_format for {vars(config)}") continue input_dict = config.random_inputs() @@ -526,19 +709,168 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea del session + format_str = InputFormats.input_format_str(input_format) + # compute TFLOPS per second - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = None + if past_sequence_length == 0: + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": enable_cuda_graph, + "format": format_str, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": intra_op_num_threads, + "average_latency": average_latency, + "tflops": speed, + "kernel": kernel, + } + csv_writer.writerow(row) - format = InputFormats.input_format_str(input_format) + speed = f"{speed:.2f}" if speed is not None else "NA" print( - f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" ) +def run_torch_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + causal: bool = False, +): + configs = get_test_configs(use_gpu) + + if use_gpu: + if not torch.cuda.is_available(): + return + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + backends = [ + None, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + SDPBackend.MATH, + ] + else: + device = torch.device("cpu") + dtype = torch.float32 + backends = [None] + + backend_names = { + SDPBackend.FLASH_ATTENTION: "torch:flash", + SDPBackend.EFFICIENT_ATTENTION: "torch:efficient", + SDPBackend.CUDNN_ATTENTION: "torch:cudnn", + SDPBackend.MATH: "torch:math", + None: "torch:default", + } + + # Test PyTorch latency + for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + for backend in backends: + if backend == SDPBackend.MATH and not enable_unfused: + continue + if backend == SDPBackend.FLASH_ATTENTION and platform.system() != "Linux": + continue + + backend_name = backend_names[backend] + try: + with torch.no_grad(): + torch_latency = run_torch_sdpa( + batch_size, + sequence_length, + sequence_length, + num_heads, + head_size, + causal, + has_mask=False, + mask_dim=2, + mask_dtype=torch.bool, + device=device, + dtype=dtype, + backend=backend, + ) + except RuntimeError: + continue + + speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + input_format = "Q,K,V" + print( + f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{backend_name}" + ) + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": False, + "format": input_format, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + 
"intra_op_num_threads": torch.get_num_threads(), + "average_latency": torch_latency, + "tflops": speed, + "kernel": backend_name, + } + csv_writer.writerow(row) + + +def run_tflops_tests(args): + features = "gpu" if args.use_gpu else "cpu" + if args.causal: + features += "_causal" + if args.has_past: + features += "_past" + csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + features, + "torch" if args.torch else "ort", + datetime.now().strftime("%Y%m%d-%H%M%S"), + ) + with open(csv_filename, mode="a", newline="") as csv_file: + column_names = [ + "use_gpu", + "enable_cuda_graph", + "format", + "causal", + "batch_size", + "sequence_length", + "past_sequence_length", + "num_heads", + "head_size", + "intra_op_num_threads", + "average_latency", + "tflops", + "kernel", + ] + csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) + csv_writer.writeheader() + + if args.torch: + run_torch_test(csv_writer, args.use_gpu, args.causal) + else: + run_tflops_test( + csv_writer, + use_gpu=args.use_gpu, + enable_cuda_graph=args.use_cuda_graph, + causal=args.causal, + has_past=args.has_past, + intra_op_num_threads=args.intra_op_num_threads, + ) + + def plot_prompt_performance( - sm: int, model_name: str, batch_size: int, num_heads: int, @@ -558,6 +890,7 @@ def plot_prompt_performance( "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)], } + sm = get_compute_capability() configs = [ triton.testing.Benchmark( x_names=["sequence_length"], @@ -591,13 +924,14 @@ def benchmark( sequence_length=sequence_length, num_heads=num_heads, head_size=head_size, - causal=True, + causal=False, past_sequence_length=0, kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, max_cache_sequence_length=max_seq_len, provider="CUDAExecutionProvider", enable_cuda_graph=False, device=device, + dtype=torch.float16, use_kv_cache=False, input_format=InputFormats.convert(input_format), ) @@ -609,14 +943,14 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_performance_test(sm: int): +def run_bert_performance_test(): """ Run performance tests for prompt and token generation. """ configures = [ - (1, 32, 128, 8192, "TNLGv4"), - (4, 32, 128, 8192, "TNLGv4"), + # (1, 32, 128, 8192, "TNLGv4"), + # (4, 32, 128, 8192, "TNLGv4"), (1, 12, 64, 1024, "BertBase"), (16, 12, 64, 1024, "BertBase"), (1, 16, 64, 1024, "BertLarge"), @@ -625,7 +959,6 @@ def run_performance_test(sm: int): for batch_size, num_heads, head_size, max_seq_len, model_name in configures: plot_prompt_performance( - sm=sm, batch_size=batch_size, num_heads=num_heads, head_size=head_size, @@ -634,18 +967,84 @@ def run_performance_test(sm: int): ) +def _parse_arguments(): + parser = argparse.ArgumentParser(description="Benchmark MultiHeadAttention for ONNX Runtime and PyTorch.") + + parser.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for inference.", + ) + parser.set_defaults(use_gpu=False) + + parser.add_argument( + "--use_cuda_graph", + required=False, + action="store_true", + help="Use cuda graph in onnxruntime.", + ) + parser.set_defaults(use_cuda_graph=False) + + parser.add_argument( + "--intra_op_num_threads", + required=False, + type=int, + choices=[0, 1, 2, 4, 8, 16], + default=0, + help="intra_op_num_threads for onnxruntime. 
", + ) + + parser.add_argument( + "--has_past", + required=False, + action="store_true", + help="whether past_sequence_length > 0", + ) + parser.set_defaults(has_past=False) + + parser.add_argument( + "--causal", + required=False, + action="store_true", + help="test unidirectional", + ) + parser.set_defaults(causal=False) + + parser.add_argument( + "--torch", + required=False, + action="store_true", + help="test pytorch instead of onnxruntime", + ) + parser.set_defaults(torch=False) + + args = parser.parse_args() + + return args + + if __name__ == "__main__": - if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): - # Test CUDA provider - major, minor = torch.cuda.get_device_capability() - sm = major * 10 + minor + args = _parse_arguments() + print(f"arguments:{args}") + + if args.has_past: + assert args.causal, "--has_past need --causal specified" + + if args.use_gpu: + assert args.torch or not args.causal, "no causal cuda kernel in MHA op" + assert torch.cuda.is_available() + if not args.torch: + assert "CUDAExecutionProvider" in get_available_providers() + if args.torch: + assert Version(torch.__version__) >= Version("2.3.0") + assert args.has_past is False + + if args.use_gpu and not args.torch: if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_performance_test(sm) - - run_tflops_test(use_gpu=True, enable_cuda_graph=True) + run_bert_performance_test() - # Test CPU provider - run_tflops_test(use_gpu=False, enable_cuda_graph=False) + run_tflops_tests(args) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 7b21cf1cc1e0..613543d0172d 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -1,14 +1,40 @@ -echo "flash attention v2" -ORT_DISABLE_FLASH_ATTENTION=0 ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV=0 python benchmark_mha.py | tee result.txt +#!/bin/sh -echo "===" -echo "TensorRT attention kernels - cross attention (when kv_seq_len <= 128) or fused attention (when seq_len <= 384) or flash attention (seq_len > 384)" -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- -echo "===" -echo "Memory Efficient attention" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" -echo "===" -echo "Unfused Attention (some configurations might fail)" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +export CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py +ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +# Pytorch SDPA does not support causal attention with past state, we only test ORT here. 
+python benchmark_mha.py --causal --has_past + +cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index ff473cc2ced9..a35d02b0b9d5 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -10,36 +10,56 @@ import concurrent.futures import itertools import unittest -from enum import IntEnum from typing import Dict, List, Optional import numpy import torch -from benchmark_mha import ( - InputFormats, - MultiHeadAttentionConfig, - OrtMultiHeadAttention, - create_multi_head_attention_onnx_model, -) +from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention, SdpaKernel, create_ort_session from einops import rearrange from parameterized import parameterized import onnxruntime -from onnxruntime import InferenceSession -class SdpaKernel(IntEnum): - """Bit flags for sdpa_kernel CUDA provider option""" +def get_provider_support_info(provider: str, use_kv_cache: bool): + if provider == "CUDAExecutionProvider": + if not use_kv_cache: + formats = [ + InputFormats.Q_K_V_BSNH_BSNH_BSNH, + InputFormats.Q_KV_BSNH_BSN2H, + InputFormats.QKV_BSN3H, + InputFormats.Q_K_V_BSNH_BNSH_BNSH, + ] + else: + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + else: + assert provider == "CPUExecutionProvider" + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + if not use_kv_cache: + formats.append(InputFormats.Q_K_V_BSNH_BNSH_BNSH) + device = torch.device("cpu") + dtype = torch.float + return device, dtype, formats + + +def get_bias_support(format: InputFormats): + if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + return [True, False] + + if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return [True, False] + + if format == InputFormats.Q_KV_BSNH_BSN2H: + return [False] + + if format == InputFormats.QKV_BSN3H: + return [True, False] - DEFAULT = 0 - FLASH_ATTENTION = 1 - EFFICIENT_ATTENTION = 2 - TRT_FUSED_ATTENTION = 4 - CUDNN_FLASH_ATTENTION = 8 - MATH = 16 - TRT_FLASH_ATTENTION = 32 - TRT_CROSS_ATTENTION = 64 - TRT_CAUSAL_ATTENTION = 128 + raise RuntimeError(f"Unknown format: {format}") def attention_reference( @@ -105,8 +125,8 @@ def attention_reference( def mha_with_past_reference( config: MultiHeadAttentionConfig, - past_k: torch.Tensor, - past_v: torch.Tensor, + past_k: Optional[torch.Tensor], + past_v: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -115,41 +135,23 @@ def mha_with_past_reference( ): assert config.kv_sequence_length == config.sequence_length assert config.use_kv_cache - assert past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) # both BNSH format - assert past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) # both BNSH format - - present_k = torch.cat((past_k, k), dim=2) - present_v = torch.cat((past_v, v), dim=2) + if past_k is not None: + assert ( + past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) + ), f"expect BNSH format: {past_k.shape=} {k.shape=}" + + if past_v is not None: + assert ( + past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) + ), f"expect BNSH format: {past_v.shape=} {v.shape=}" + + present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k + present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v out = attention_reference(config.head_size, q, 
present_k, present_v, scale=scale, mask=mask) return out, present_k, present_v -def get_provider_support_info(provider: str, use_kv_cache: bool): - if provider == "CUDAExecutionProvider": - if not use_kv_cache: - formats = [ - InputFormats.Q_K_V_BSNH_BSNH_BSNH, - InputFormats.Q_KV_BSNH_BSN2H, - InputFormats.QKV_BSN3H, - InputFormats.Q_K_V_BSNH_BNSH_BNSH, - ] - else: - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - dtype = torch.float16 - else: - assert provider == "CPUExecutionProvider" - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - if not use_kv_cache: - formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH) - device = torch.device("cpu") - dtype = torch.float - return device, dtype, formats - - def get_compute_capability(): if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): major, minor = torch.cuda.get_device_capability() @@ -164,35 +166,38 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 2048] + sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -200,25 +205,27 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for format in formats: + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + 
num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def kv_cache_test_cases(provider: str, comprehensive: bool): @@ -227,37 +234,42 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 15, 16, 255, 256, 2048] + sequence_lengths = [1, 15, 16, 255, 256, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - sequence_length = 1 device, dtype, formats = get_provider_support_info(provider, True) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for past_sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -265,31 +277,31 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): past_sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config - - -def mha_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_test_cases(provider, comprehensive), kv_cache_test_cases(provider, comprehensive) - ) + for format in formats: + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + 
sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): @@ -364,6 +376,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device=device, dtype=dtype, use_kv_cache=True, + has_past_input=True, share_past_present_buffer=False, input_format=format, ) @@ -371,13 +384,6 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): yield configs -def multi_thread_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_multi_thread_test_cases(provider, comprehensive), - kv_cache_multi_thread_test_cases(provider, comprehensive), - ) - - def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) @@ -395,28 +401,31 @@ def parity_check_mha( if config.causal and config.provider == "CUDAExecutionProvider": return - ort_mha = OrtMultiHeadAttention(config) + ort_mha = OrtMultiHeadAttention(config, use_tf32=False) ort_outputs = ort_mha.infer() out = ort_outputs["output"] out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + no_bias_k_v = config.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH - ref_inputs = config.random_inputs() - q = ( - ref_inputs["query"] - .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - k = ( - ref_inputs["key"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - v = ( - ref_inputs["value"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) + ref_inputs = config.random_inputs(no_bias_k_v=no_bias_k_v) + q = ref_inputs["query"].reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + k = ref_inputs["key"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + v = ref_inputs["value"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + + if "bias" in ref_inputs: + bias = ref_inputs["bias"] + bias = bias.reshape((3, config.num_heads, config.head_size)) + bias_q = bias[0, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_k = bias[1, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_v = bias[2, :, :].reshape(1, 1, config.num_heads, config.head_size) + q = q + bias_q + k = k + bias_k + v = v + bias_v + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) mask = None if config.causal: @@ -425,8 +434,8 @@ def parity_check_mha( k_cache = None v_cache = None if config.use_kv_cache: - past_k = ref_inputs["past_key"] - past_v = ref_inputs["past_value"] + past_k = ref_inputs.get("past_key", None) + past_v = ref_inputs.get("past_value", None) out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) else: 
out_ref = attention_reference(config.head_size, q, k, v, mask=mask) @@ -466,7 +475,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - sdpa_kernel: int = SdpaKernel.DEFAULT, + attention_kernel=SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -475,22 +484,16 @@ def parity_check_mha_multi_threading( # For now, MHA CUDA kernel does not support causal so skip such test cases. if config.causal and config.provider == "CUDAExecutionProvider": return None + # Some kernel does not support certain input format. - if sdpa_kernel not in [ + if attention_kernel not in [ SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - if verbose: - print(f"create a shared session with {vars(config)}") - onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True) - if config.provider == "CUDAExecutionProvider": - provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)} - providers = [(config.provider, provider_options), "CPUExecutionProvider"] - else: - providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True, use_tf32=False) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -600,20 +603,34 @@ def check_parity_with_config(i: int): return None -# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. +def mha_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_test_cases(provider, comprehensive), + kv_cache_test_cases(provider, comprehensive), + ) + + +def multi_thread_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_multi_thread_test_cases(provider, comprehensive), + kv_cache_multi_thread_test_cases(provider, comprehensive), + ) + + +# Off by default so that we do not run too many tests in CI pipeline. 
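# When it is True, the case generators above also append sequence length 2048 and enumerate every
# combination of input format, bias and past input, so the parameterized test count grows
# substantially; for a local sweep, flip the assignment below to True before running the module
# (e.g. python -m unittest test_mha).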
comprehensive_mode = False class TestMultiHeadAttention(unittest.TestCase): @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cuda(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) @parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cpu(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) - def run_mha_cuda_multi_threading(self, spda_kernel): + def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): test_inputs = [] for config in configs: @@ -626,23 +643,30 @@ def run_mha_cuda_multi_threading(self, spda_kernel): config.input_format = old_format test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs}) - exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs)) - assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}" + exception = parity_check_mha_multi_threading( + test_inputs, attention_kernel=attention_kernel, max_threads=len(configs) + ) + assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): - self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) + if get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) def test_mha_cuda_multi_threading_efficient(self): - self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + + def test_mha_cuda_multi_threading_math(self): + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.MATH) def test_mha_cuda_multi_threading_trt(self): - sm = get_compute_capability() - if sm in [75, 80, 86, 89]: + if get_compute_capability() in [75, 80, 86, 89]: self.run_mha_cuda_multi_threading( SdpaKernel.TRT_FUSED_ATTENTION | SdpaKernel.TRT_FLASH_ATTENTION - | SdpaKernel.TRT_CROSS_ATTENTION | SdpaKernel.TRT_CAUSAL_ATTENTION + | SdpaKernel.TRT_CROSS_ATTENTION ) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 52491a179c2c..7a33bf8a527c 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1959,7 +1959,9 @@ TEST(CApiTest, get_allocator_cpu) { #ifdef USE_CUDA TEST(CApiTest, get_allocator_cuda) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -2076,7 +2078,9 @@ TEST(CApiTest, io_binding_cuda) { #ifdef USE_TENSORRT Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); #else - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + 
session_options.AppendExecutionProvider_CUDA_V2(*options); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); @@ -3438,7 +3442,9 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); #else // arena is enabled but the sole initializer will still be allocated from non-arena memory Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -3890,7 +3896,9 @@ TEST(CApiTest, GitHubIssue10179) { try { const auto* model_path = MODEL_URI; Ort::SessionOptions session_options{}; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; } catch (const std::exception& e) { std::cerr << "exception: " << e.what() << "\n"; @@ -3920,7 +3928,9 @@ TEST(CApiTest, GitHubIssue10179) { TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) { const auto* model_path = SEQUENCE_MODEL_URI_2; Ort::SessionOptions session_options{}; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); diff --git a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index b7f6f7f4b9a7..5694398b9cb1 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -60,8 +60,9 @@ TEST(CApiTest, model_from_array) { create_session(so); #ifdef USE_CUDA - // test with CUDA provider when using onnxruntime as dll - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, 0)); + OrtCUDAProviderOptionsV2* options; + Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + so.AppendExecutionProvider_CUDA_V2(*options); create_session(so); #endif } diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 1885a213bdf3..4b14d50127aa 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -720,7 +720,10 @@ "^test_constantofshape_int_zeros", "^test_reduce_log_sum_empty_set_cpu", "^test_reduce_log_sum_exp_empty_set_cpu", - "^test_reduce_prod_empty_set_cpu" + "^test_reduce_prod_empty_set_cpu", + //Bug: DML EP does not execute operators with an empty input tensor + //TODO: Resolve as a graph implementation that returns a constant inf tensor with appropriate strides + "^test_reduce_min_empty_set_cpu" ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 
2a8d2de982e7..92f803030ada 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -181,6 +181,64 @@ static void propagateRecvOutputTensorElemTypes( } } +void SendShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() < 3) { + fail_shape_inference("Send must have at least three inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Send must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Send must be a scalar."); + } + } + + checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); + } + + if (ctx.getNumOutputs() != 1) { + fail_shape_inference("Send must have one output."); + } + + auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_element_type->set_elem_type(TensorProto::BOOL); + ONNX_NAMESPACE::TensorShapeProto output_shape; + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); +} + +void RecvShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() != 2) { + fail_shape_inference("Recv must have two inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Recv must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Recv must be a scalar."); + } + } + } + + if (ctx.getNumOutputs() < 2) { + fail_shape_inference("Recv must have at least two outputs."); + } + + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); +} + TensorProto ToDimensionOneFloatTensor(float value) { auto t = ToTensor(std::vector({value})); t.add_dims(1); @@ -3388,30 +3446,7 @@ Return true if all elements are true and false otherwise. 
"Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() < 3) { - fail_shape_inference("Send must have at least three inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Send must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Send must be a scalar."); - } - - checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); - } - - if (ctx.getNumOutputs() != 1) { - fail_shape_inference("Send must have one output."); - } - - auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); - output_element_type->set_elem_type(TensorProto::BOOL); - ONNX_NAMESPACE::TensorShapeProto output_shape; - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + SendShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(Recv) @@ -3437,26 +3472,7 @@ Return true if all elements are true and false otherwise. "Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() != 2) { - fail_shape_inference("Recv must have two inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Recv must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Recv must be a scalar."); - } - } - - if (ctx.getNumOutputs() < 2) { - fail_shape_inference("Recv must have at least two outputs."); - } - - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); + RecvShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(MegatronF) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 8d110c692751..1135ef41cfc4 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -67,410 +67,422 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap; * or not. * 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not. 
*/ -const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { - static InlinedHashMap> recomputable_op_table_map; - if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) { - return recomputable_op_table_map.at(probe_op_level); - } +InlinedHashMap> InitializeRecomputableOpTable() { + InlinedHashMap> recomputable_op_table_map; + + constexpr const int basic_op_level = static_cast(ProbeLevel::Basic); + recomputable_op_table_map.insert({basic_op_level, InlinedHashMap()}); + auto& basic_recomputable_op_table = recomputable_op_table_map.at(basic_op_level); + + basic_recomputable_op_table.insert({ + { + utils::GetFullQualifiedOpName("Add", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {9, {}}, + {14, {}}, + {15, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain), + { + {1, {1, 2}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Cast", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {9, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain), + { + {1, {}}, + + }, + }, + { + utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain), + { + {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor + {20, {0}}, + }, + }, + { + utils::GetFullQualifiedOpName("Cos", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), + { + // The axis input is trivial + {11, {1}}, + {14, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), + { + // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. + {12, {1, 2}}, // ignore ratio and training_mode + {13, {1, 2}}, + }, + }, + { + utils::GetFullQualifiedOpName("Div", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Einsum", kOnnxDomain), + { + {12, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Equal", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {11, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Expand", kOnnxDomain), + { + {8, {1}}, // Ignore the shape. 
+ {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("FastGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain), + { + {1, {1}}, // ignore the indices + }, + }, + { + utils::GetFullQualifiedOpName("Gather", kOnnxDomain), + { + {1, {1}}, // ignore the indices + {11, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kOnnxDomain), + { + {20, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gemm", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {9, {}}, + {11, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Less", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain), + { + {1, {0}}, // Ignore CPU input. + }, + }, + { + utils::GetFullQualifiedOpName("Mul", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Neg", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("NonZero", kOnnxDomain), + { + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain), + { + {1, {1, 2}}, // ignore the indices and unflatten_dims + }, + }, + { + // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check. + utils::GetFullQualifiedOpName("PythonOp", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Range", kOnnxDomain), + { + {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars. + }, + }, + { + utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), + { + {1, {}}, + {5, {}}, // ignore the shape. 
+ {13, {}}, + {14, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Sin", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Slice", kOnnxDomain), + { + {1, {}}, + {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional) + {11, {1, 2, 3, 4}}, + {13, {1, 2, 3, 4}}, + }, + }, + { + utils::GetFullQualifiedOpName("Split", kOnnxDomain), + { + {1, {1}}, // ignore split (optional) + {2, {}}, + {11, {}}, + {13, {1}}, // ignore the split (optional) + {18, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Sub", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Tile", kOnnxDomain), + { + {1, {1, 2}}, + {6, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), + { + {1, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Trilu", kOnnxDomain), + { + {14, {1}}, // ignore k (optional) + }, + }, + { + utils::GetFullQualifiedOpName("QuickGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Where", kOnnxDomain), + { + {9, {}}, + {16, {}}, + }, + }, + + }); + + constexpr const int advanced_op_level = static_cast(ProbeLevel::Advanced); + recomputable_op_table_map.insert({advanced_op_level, InlinedHashMap()}); + auto& advanced_recomputable_op_table = recomputable_op_table_map.at(advanced_op_level); + // Append basic_recomputable_op_table to advanced_recomputable_op_table. + advanced_recomputable_op_table.insert(recomputable_op_table_map.at(basic_op_level).begin(), + recomputable_op_table_map.at(basic_op_level).end()); + + advanced_recomputable_op_table.insert({ + { + utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain), + { + {1, {2}}, // ignore ratio (optional) + }, + }, + { + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have LayerNormalization, + // while our contrib op defined LayerNormalization in opset 1 in ONNX domain. + {1, {}}, + {17, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), + { + {1, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have SimplifiedLayerNormalization, + // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain. 
+ {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {}}, + }, + }, + }); + + return recomputable_op_table_map; +} - recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); - auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level); - if (probe_op_level >= static_cast(ProbeLevel::Basic)) { - recomputable_op_table.insert({ - { - utils::GetFullQualifiedOpName("Add", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {7, {}}, - {13, {}}, - {14, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {7, {}}, - {9, {}}, - {14, {}}, - {15, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("BiasDropout", kMSDomain), - { - {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) - }, - }, - { - utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain), - { - {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) - }, - }, - { - utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain), - { - {1, {1, 2}}, // ignore ratio (optional) and training mode (optional) - }, - }, - { - utils::GetFullQualifiedOpName("Cast", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {9, {}}, - {13, {}}, - {19, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain), - { - {1, {}}, - - }, - }, - { - utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain), - { - {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor - {20, {0}}, - }, - }, - { - utils::GetFullQualifiedOpName("Cos", kOnnxDomain), - { - {7, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), - { - // The axis input is trivial - {11, {1}}, - {14, {1}}, - }, - }, - { - utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), - { - // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. - {12, {1, 2}}, // ignore ratio and training_mode - {13, {1, 2}}, - }, - }, - { - utils::GetFullQualifiedOpName("Div", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {7, {}}, - {13, {}}, - {14, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Einsum", kOnnxDomain), - { - {12, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Equal", kOnnxDomain), - { - {1, {}}, - {7, {}}, - {11, {}}, - {13, {}}, - {19, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Expand", kOnnxDomain), - { - {8, {1}}, // Ignore the shape. 
- {13, {1}}, - }, - }, - { - utils::GetFullQualifiedOpName("FastGelu", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain), - { - {1, {1}}, // ignore the indices - }, - }, - { - utils::GetFullQualifiedOpName("Gather", kOnnxDomain), - { - {1, {1}}, // ignore the indices - {11, {1}}, - {13, {1}}, - }, - }, - { - utils::GetFullQualifiedOpName("Gelu", kOnnxDomain), - { - {20, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Gelu", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Gemm", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {7, {}}, - {9, {}}, - {11, {}}, - {13, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Less", kOnnxDomain), - { - {1, {}}, - {7, {}}, - {9, {}}, - {13, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain), - { - {1, {0}}, // Ignore CPU input. - }, - }, - { - utils::GetFullQualifiedOpName("Mul", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {7, {}}, - {13, {}}, - {14, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Neg", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {13, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("NonZero", kOnnxDomain), - { - {9, {}}, - {13, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain), - { - {1, {1, 2}}, // ignore the indices and unflatten_dims - }, - }, - { - // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check. - utils::GetFullQualifiedOpName("PythonOp", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Range", kOnnxDomain), - { - {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars. - }, - }, - { - utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), - { - {1, {}}, - {5, {}}, // ignore the shape. 
- {13, {}}, - {14, {}}, - {19, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Sin", kOnnxDomain), - { - {7, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Slice", kOnnxDomain), - { - {1, {}}, - {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional) - {11, {1, 2, 3, 4}}, - {13, {1, 2, 3, 4}}, - }, - }, - { - utils::GetFullQualifiedOpName("Split", kOnnxDomain), - { - {1, {1}}, // ignore split (optional) - {2, {}}, - {11, {}}, - {13, {1}}, // ignore the split (optional) - {18, {1}}, - }, - }, - { - utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain), - { - {1, {}}, - {11, {}}, - {13, {1}}, // ignore the axes (optional) - }, - }, - { - utils::GetFullQualifiedOpName("Sub", kOnnxDomain), - { - {1, {}}, - {6, {}}, - {7, {}}, - {13, {}}, - {14, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Tile", kOnnxDomain), - { - {1, {1, 2}}, - {6, {1}}, - {13, {1}}, - }, - }, - { - utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), - { - {1, {}}, - {13, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Trilu", kOnnxDomain), - { - {14, {1}}, // ignore k (optional) - }, - }, - { - utils::GetFullQualifiedOpName("QuickGelu", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain), - { - {1, {}}, - {11, {}}, - {13, {1}}, // ignore the axes (optional) - }, - }, - { - utils::GetFullQualifiedOpName("Where", kOnnxDomain), - { - {9, {}}, - {16, {}}, - }, - }, - - }); - } +const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { + static InlinedHashMap> + recomputable_op_table_map = InitializeRecomputableOpTable(); - if (probe_op_level >= static_cast(ProbeLevel::Advanced)) { - recomputable_op_table.insert({ - { - utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain), - { - {1, {2}}, // ignore ratio (optional) - }, - }, - { - utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), - { - // Opset 1 in ONNX official does not have LayerNormalization, - // while our contrib op defined LayerNormalization in opset 1 in ONNX domain. - {1, {}}, - {17, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), - { - {1, {}}, - {9, {}}, - {13, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain), - { - // Opset 1 in ONNX official does not have SimplifiedLayerNormalization, - // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain. 
- {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain), - { - {1, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), - { - {1, {}}, - {11, {}}, - {13, {}}, - }, - }, - }); - } + ORT_ENFORCE(recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end(), + "Cannot get recomputable op table, probe level: ", probe_op_level); - return recomputable_op_table; + return recomputable_op_table_map.at(probe_op_level); } /** diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index 90c97eed0c6d..be25eefb201d 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -542,8 +542,9 @@ TEST(TrainingApiTest, OptimStep) { std::string param_name = "fc2.weight"; // before training, check if optim state is initialized to 0 onnxruntime::training::api::OptimizerCheckpointState& optimizer_states = state.optimizer_checkpoint_state; + std::shared_ptr group0_states = optimizer_states.group_named_optimizer_states["group0"]; onnxruntime::training::api::ParameterOptimizerState& param_state = - optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name); + group0_states->param_named_optimizer_states.at(param_name); OrtValue& moment_1 = param_state.at("momentum0"); std::vector param_vec_before_optimizer_step; diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index 56029b34c24d..cbff1891b8c8 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -449,7 +449,7 @@ Status FromOptimizerState(const OptimizerCheckpointState& optimizer_state, fbs_optimizer_groups.reserve(optimizer_state.group_named_optimizer_states.size()); for (const auto& group_name : SortedKeys(optimizer_state.group_named_optimizer_states)) { - const std::shared_ptr& group_optimizer_state_ptr = + std::shared_ptr group_optimizer_state_ptr = optimizer_state.group_named_optimizer_states.at(group_name); std::vector> optimizer_states; diff --git a/pyproject.toml b/pyproject.toml index 1c3a719fb544..6429df2722b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ ignore = [ "G004", # FIXME: Enable when the rule can be autofixed "N803", # Argument casing "N812", # Allow import torch.nn.functional as F + "N813", # Allow importing camelcase names in lowercase "N999", # Module names "NPY002", # np.random.Generator may not always fit our use cases "PERF203", # "try-except-in-loop" only affects Python <3.11, and the improvement is minor; can have false positives diff --git a/setup.py b/setup.py index 51feedcfd328..1fa297e22acd 100644 --- a/setup.py +++ b/setup.py @@ -196,6 +196,7 @@ def run(self): to_preload_cann = [] cuda_dependencies = [ + "libcuda.so.1", "libcublas.so.11", "libcublas.so.12", "libcublasLt.so.11", diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py index 99ecaf677f33..a3f603b0beda 100755 --- a/tools/ci_build/get_docker_image.py +++ b/tools/ci_build/get_docker_image.py @@ -98,17 +98,19 @@ def main(): ) if use_container_registry: + run(args.docker_path, "buildx", "create", "--driver=docker-container", "--name=container_builder") run( 
args.docker_path, "--log-level", "error", "buildx", "build", - "--push", + "--load", "--tag", full_image_name, - "--cache-from", - full_image_name, + "--cache-from=type=registry,ref=" + full_image_name, + "--builder", + "container_builder", "--build-arg", "BUILDKIT_INLINE_CACHE=1", *shlex.split(args.docker_build_args), @@ -116,24 +118,10 @@ def main(): args.dockerfile, args.context, ) - elif args.use_imagecache: - log.info("Building image with pipeline cache...") run( args.docker_path, - "--log-level", - "error", - "buildx", - "build", - "--tag", - full_image_name, - "--cache-from", + "push", full_image_name, - "--build-arg", - "BUILDKIT_INLINE_CACHE=1", - *shlex.split(args.docker_build_args), - "-f", - args.dockerfile, - args.context, ) else: log.info("Building image...") diff --git a/tools/ci_build/github/android/mobile_package.required_operators.config b/tools/ci_build/github/android/mobile_package.required_operators.config deleted file mode 100644 index 6a6ba8c3c90e..000000000000 --- a/tools/ci_build/github/android/mobile_package.required_operators.config +++ /dev/null @@ -1,46 +0,0 @@ -# Android package for ORT Mobile operator and type reduction configuration -# -# The list of operators was generated from: -# - the ONNX operators use by the tf2onnx tflite converter -# - the operators used in a set of tflite models from tfhub, the tflite examples, and the mlperf mobile models -# - models were optimized with optimizations set to 'basic', 'extended' and 'all' -# - see the readme file for full details - -# allow float, int8, uint8. operators that manipulate shapes or indices have int32 and int64 enabled internally. -!globally_allowed_types;float,int8_t,uint8_t - -# ops used by the tf2onnx tflite converter. -ai.onnx;12,13,14,15;Abs,Add,And,ArgMax,ArgMin,AveragePool,Cast,Ceil,Clip,Concat,ConstantOfShape,Conv,ConvTranspose,Cos,CumSum,DepthToSpace,DequantizeLinear,Div,DynamicQuantizeLinear,Elu,Equal,Exp,Expand,Flatten,Floor,Gather,GatherND,Gemm,Greater,GreaterOrEqual,Identity,If,LRN,LeakyRelu,Less,LessOrEqual,Log,LogSoftmax,Loop,MatMul,Max,MaxPool,Mean,Min,Mul,Neg,NonMaxSuppression,NonZero,Not,Or,PRelu,Pad,Pow,QuantizeLinear,Range,Reciprocal,ReduceMax,ReduceMean,ReduceMin,ReduceProd,ReduceSum,Relu,Reshape,Resize,ReverseSequence,Round,ScatterND,Shape,Sigmoid,Sin,Size,Slice,Softmax,SpaceToDepth,Split,Sqrt,Squeeze,Sub,Sum,Tanh,ThresholdedRelu,Tile,TopK,Transpose,Unique,Unsqueeze,Where - -# other ops found in test models -ai.onnx;12,13,14,15;Erf,GlobalAveragePool,InstanceNormalization,HardSigmoid,MatMulInteger,QLinearConv,QLinearMatMul - -# Control flow ops -# - If and Loop are covered by the tflite converter list -# - Scan tends to be used in speech models (it's more efficient than Loop) so include it for support of those -ai.onnx;12,13,14,15;Scan - -# Changed ONNX ops by opset version for the above ops. This list is to provide context as to how much was added -# for each additional opset we support. 
-# -# opset 13 -# Abs,Add,ArgMax,ArgMin,Cast,Ceil,Clip,Concat,DepthToSpace,DequantizeLinear,Div,Equal,Erf,Exp,Expand,Flatten,Floor, -# Gather,GatherND,Gemm,Greater,Identity,If,LRN,Less,Log,LogSoftmax,Loop,MatMul,Max,Mean,Min,Mul,Neg,NonZero,Pad, -# Pow,QuantizeLinear,Reciprocal,ReduceMax,ReduceMean,ReduceMin,ReduceProd,ReduceSum,Relu,Reshape,Resize, -# ScatterND,Shape,Sigmoid,Size,Slice,Softmax,SpaceToDepth,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Tile,Transpose,Unsqueeze -# opset 14 -# Add,CumSum,Div,Identity,Mul,Relu,Reshape,Sub -# opset 15 -# Pow,Shape - - -# internal ops added by optimizers -# Note: LayerNormalization is an internal op even though it is (incorrectly) registered in the ONNX domain. -ai.onnx;1;LayerNormalization -com.microsoft;1;DynamicQuantizeMatMul,FusedConv,FusedGemm,FusedMatMul,Gelu,MatMulIntegerToFloat,NhwcMaxPool,QLinearAdd,QLinearAveragePool,QLinearConv,QLinearGlobalAveragePool,QLinearMul,QLinearSigmoid - -# NHWC transformer also uses this, so assuming it's valuable enough to include -com.microsoft;1;QLinearLeakyRelu - -# Quantized contrib ops that are registered but no usage was found. Excluding for now. -# com.microsoft;1;DynamicQuantizeLSTM,QAttention diff --git a/tools/ci_build/github/android/mobile_package.required_operators.readme.txt b/tools/ci_build/github/android/mobile_package.required_operators.readme.txt deleted file mode 100644 index 9e60cba4a42f..000000000000 --- a/tools/ci_build/github/android/mobile_package.required_operators.readme.txt +++ /dev/null @@ -1,82 +0,0 @@ -The required operators config file was generated from a number of models (details below), with optimizations run using 'all', 'extended' and 'basic'. -Following that, some additional operators were added, as per the comments in the config file. - -The global types to support were selected to support quantized and float32 models -Additionally there is internal 'required' type support for int32 and int64_t in selected operators that work with the dimensions in a shape or indices so that we don't need to enable those types at a global level. 
- -Models used as input (Converted using tf2onnx in early March 2021): - Models from TF Lite Examples https://www.tensorflow.org/lite/examples - - lite-model_deeplabv3_1_metadata_2.tflite.onnx - - lite-model_esrgan-tf2_1.tflite.onnx - - lite-model_mobilebert_1_metadata_1.tflite.onnx - - mnist.tflite.onnx - - mobilenet_v1_1.0_224_quant.tflite.onnx - - model_history10_top100.tflite.onnx - - posenet_mobilenet_float_075_1_default_1.tflite.onnx - - posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite.onnx - - ssd_mobilenet_v1_1_metadata_1.tflite.onnx - - text_classification_v2.tflite.onnx - -Assorted models from TF Hub that were able to be converted with tf2onnx - TFLite v1 https://tfhub.dev/s?deployment-format=lite&tf-version=tf1 - - efficientnet_lite1_fp32_2.tflite.onnx - - efficientnet_lite1_int8_2.tflite.onnx - - efficientnet_lite4_fp32_2.tflite.onnx - - efficientnet_lite4_int8_2.tflite.onnx - - lite-model_aiy_vision_classifier_birds_V1_3.tflite.onnx - - lite-model_aiy_vision_classifier_food_V1_1.tflite.onnx - - lite-model_aiy_vision_classifier_plants_V1_3.tflite.onnx - - lite-model_midas_v2_1_small_1_lite_1.tflite.onnx - - lite-model_object_detection_mobile_object_labeler_v1_1.tflite.onnx - - magenta_arbitrary-image-stylization-v1-256_int8_prediction_1.tflite.onnx - - magenta_arbitrary-image-stylization-v1-256_int8_transfer_1.tflite.onnx - - object_detection_mobile_object_localizer_v1_1_default_1.tflite.onnx - - TFLite v2 https://tfhub.dev/s?deployment-format=lite&tf-version=tf2 - - tf2\albert_lite_base_squadv1_1.tflite.onnx - - tf2\lite-model_disease-classification_1.tflite.onnx - - tf2\lite-model_efficientdet_lite0_detection_default_1.tflite.onnx - - tf2\lite-model_efficientdet_lite0_int8_1.tflite.onnx - - tf2\lite-model_efficientdet_lite1_detection_default_1.tflite.onnx - - tf2\lite-model_efficientdet_lite2_detection_default_1.tflite.onnx - - tf2\lite-model_efficientdet_lite3_detection_default_1.tflite.onnx - - tf2\lite-model_efficientdet_lite4_detection_default_1.tflite.onnx - - tf2\lite-model_esrgan-tf2_1.tflite.onnx - - tf2\lite-model_german-mbmelgan_lite_1.tflite.onnx - - tf2\lite-model_nonsemantic-speech-benchmark_trill-distilled_1.tflite.onnx - - tf2\lite-model_yamnet_tflite_1.tflite.onnx - -Models from MLPerf Mobile - (mainly models converted from TFLite and quantized in different ways, but some from TF for completeness as those also have batch handling) - - deeplabv3_mnv2_ade20k_float-int8.onnx - - deeplabv3_mnv2_ade20k_float.onnx - - deeplabv3_mnv2_ade20k-qdq.onnx - - mobilebert-int8.onnx - - mobilebert-qdq.onnx - - mobilebert.onnx - - mobiledet-int8.onnx - - mobiledet-qdq.onnx - - mobiledet.onnx - - mobilenet_edgetpu_224_1.0_float-int8.onnx - - mobilenet_edgetpu_224_1.0_float.onnx - - mobilenet_edgetpu_224_1.0-qdq.onnx - - mobilenet_v1_1.0_224.opset12.onnx - - resnet50_v1-int8.onnx - - resnet50_v1.onnx - - ssd_mobilenet_v2_300_float-int8.onnx - - ssd_mobilenet_v2_300_float.onnx - - ssd_mobilenet_v2_300-qdq.onnx - -Other - Mobilenet v2 and v3 from pytorch - - https://pytorch.org/vision/stable/models.html - - pytorch.mobilenet_v2_float.onnx - - pytorch.mobilenet_v2_uint8.onnx - - pytorch.mobilenet_v3_small.onnx - Other assorted pytorch models - - Huggingface mobilebert-uncased (https://huggingface.co/transformers/serialization.html, https://huggingface.co/google/mobilebert-uncased) - - SuperResolution (https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html) - - DeepLabV3 (https://pytorch.org/tutorials/beginner/deeplabv3_on_android.html) - - 
EfficientNet (https://github.com/lukemelas/EfficientNet-PyTorch) - - SSD Mobilenet V1 and V2 (https://github.com/qfgaohao/pytorch-ssd) - - Wav2Vec 2.0 (adapted from https://github.com/pytorch/ios-demo-app/blob/f2b9aa196821c136d3299b99c5dd592de1fa1776/SpeechRecognition/create_wav2vec2.py) diff --git a/tools/ci_build/github/android/run_nnapi_code_coverage.sh b/tools/ci_build/github/android/run_nnapi_code_coverage.sh deleted file mode 100755 index 472e824eaa47..000000000000 --- a/tools/ci_build/github/android/run_nnapi_code_coverage.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -# This script will run ORT build for Android with code coverage option - -set -e -set -x - -if [ $# -ne 1 ]; then - echo "One command line argument, the ROOT root directory, is expected" -fi - -ORT_ROOT=$1 -# Build and run onnxruntime using NNAPI execution provider targeting android emulator -python3 ${ORT_ROOT}/tools/ci_build/build.py \ - --android \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME \ - --android_ndk_path $ANDROID_NDK_HOME \ - --android_abi=x86_64 \ - --android_api=29 \ - --skip_submodule_sync \ - --parallel \ - --use_nnapi \ - --cmake_generator=Ninja \ - --build_java \ - --path_to_protoc_exe $ORT_ROOT/protobuf_install/bin/protoc \ - --code_coverage - -# Install gcovr -python3 -m pip install gcovr - -# Retrieve runtime code coverage files from the emulator and analyze -python3 ${ORT_ROOT}/tools/ci_build/coverage.py \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME - diff --git a/tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh b/tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh index 317048506ac6..a2178337e687 100755 --- a/tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh +++ b/tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh @@ -23,10 +23,13 @@ ORT_POD_VERSION=${4:?${USAGE_TEXT}} POD_ARCHIVE_BASENAME="pod-archive-${POD_NAME}-${ORT_POD_VERSION}.zip" PODSPEC_BASENAME="${POD_NAME}.podspec" +echo "Contents of ${BINARIES_STAGING_DIR}/${POD_NAME}:" +ls -lR "${BINARIES_STAGING_DIR}/${POD_NAME}" + pushd "${BINARIES_STAGING_DIR}/${POD_NAME}" # assemble the files in the artifacts staging directory -zip -r "${ARTIFACTS_STAGING_DIR}/${POD_ARCHIVE_BASENAME}" ./* --exclude "${PODSPEC_BASENAME}" +zip -r -y "${ARTIFACTS_STAGING_DIR}/${POD_ARCHIVE_BASENAME}" ./* --exclude "${PODSPEC_BASENAME}" cp "${PODSPEC_BASENAME}" "${ARTIFACTS_STAGING_DIR}/${PODSPEC_BASENAME}" popd diff --git a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py index 5014ba11d983..71aeb9e7b030 100755 --- a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py +++ b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py @@ -57,6 +57,11 @@ def parse_args(): ) parser.add_argument("--test", action="store_true", help="Run tests on the framework and pod package files.") + parser.add_argument( + "--skip-build", + action="store_true", + help="Use build from previous run. 
Useful to debug test issues or packaging changes.", + ) build_framework_group = parser.add_argument_group( title="iOS framework build arguments", @@ -114,7 +119,8 @@ def main(): build_apple_framework_args += ["--build_dir", str(build_dir), args.build_settings_file] - run(build_apple_framework_args) + if not args.skip_build: + run(build_apple_framework_args) if args.test: test_apple_packages_args = [ @@ -171,7 +177,8 @@ def main(): def move_dir(src, dst): if dst.is_dir(): shutil.rmtree(dst) - shutil.move(src, dst) + shutil.copytree(src, dst, symlinks=True) + shutil.rmtree(src) move_dir(c_pod_staging_dir, staging_dir / c_pod_name) move_dir(objc_pod_staging_dir, staging_dir / objc_pod_name) diff --git a/tools/ci_build/github/apple/build_apple_framework.py b/tools/ci_build/github/apple/build_apple_framework.py index 3cd7a3af7062..5a3b242c2a38 100644 --- a/tools/ci_build/github/apple/build_apple_framework.py +++ b/tools/ci_build/github/apple/build_apple_framework.py @@ -89,18 +89,52 @@ def _build_for_apple_sysroot( pathlib.Path(framework_dir).mkdir(parents=True, exist_ok=True) # copy the Info.plist, framework_info.json, and header files - shutil.copy(info_plist_path, framework_dir) - shutil.copy(framework_info_path, os.path.dirname(framework_dir)) - header_dir = os.path.join(framework_dir, "Headers") - pathlib.Path(header_dir).mkdir(parents=True, exist_ok=True) - for _header in headers: - shutil.copy(_header, header_dir) - - # use lipo to create a fat ort library - lipo_command = ["lipo", "-create"] - lipo_command += ort_libs - lipo_command += ["-output", os.path.join(framework_dir, "onnxruntime")] - subprocess.run(lipo_command, shell=False, check=True) + + # macos requires different framework structure: + # https://developer.apple.com/library/archive/documentation/MacOSX/Conceptual/BPFrameworks/Concepts/FrameworkAnatomy.html + if sysroot == "macosx" or sysroot == "macabi": + # create headers and resources directory + header_dir = os.path.join(framework_dir, "Versions", "A", "Headers") + resource_dir = os.path.join(framework_dir, "Versions", "A", "Resources") + pathlib.Path(header_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(resource_dir).mkdir(parents=True, exist_ok=True) + + shutil.copy(info_plist_path, resource_dir) + shutil.copy(framework_info_path, os.path.dirname(framework_dir)) + + for _header in headers: + shutil.copy(_header, header_dir) + + # use lipo to create a fat ort library + lipo_command = ["lipo", "-create"] + lipo_command += ort_libs + lipo_command += ["-output", os.path.join(framework_dir, "Versions", "A", "onnxruntime")] + subprocess.run(lipo_command, shell=False, check=True) + + # create the symbolic link + pathlib.Path(os.path.join(framework_dir, "Versions", "Current")).symlink_to("A", target_is_directory=True) + pathlib.Path(os.path.join(framework_dir, "Headers")).symlink_to( + "Versions/Current/Headers", target_is_directory=True + ) + pathlib.Path(os.path.join(framework_dir, "Resources")).symlink_to( + "Versions/Current/Resources", target_is_directory=True + ) + pathlib.Path(os.path.join(framework_dir, "onnxruntime")).symlink_to("Versions/Current/onnxruntime") + + else: + shutil.copy(info_plist_path, framework_dir) + shutil.copy(framework_info_path, os.path.dirname(framework_dir)) + header_dir = os.path.join(framework_dir, "Headers") + pathlib.Path(header_dir).mkdir(parents=True, exist_ok=True) + + for _header in headers: + shutil.copy(_header, header_dir) + + # use lipo to create a fat ort library + lipo_command = ["lipo", "-create"] + lipo_command += ort_libs 
+ lipo_command += ["-output", os.path.join(framework_dir, "onnxruntime")] + subprocess.run(lipo_command, shell=False, check=True) return framework_dir @@ -166,7 +200,7 @@ def _build_package(args): xcframework_dir = os.path.join(build_dir, "framework_out") pathlib.Path(xcframework_dir).mkdir(parents=True, exist_ok=True) shutil.copy(os.path.join(REPO_DIR, "LICENSE"), xcframework_dir) - shutil.copytree(public_headers_path, os.path.join(xcframework_dir, "Headers"), dirs_exist_ok=True) + shutil.copytree(public_headers_path, os.path.join(xcframework_dir, "Headers"), dirs_exist_ok=True, symlinks=True) _merge_framework_info_files(framework_info_files_to_merge, os.path.join(build_dir, "xcframework_info.json")) # remove existing xcframework if any diff --git a/tools/ci_build/github/apple/c/assemble_c_pod_package.py b/tools/ci_build/github/apple/c/assemble_c_pod_package.py index ca4f01cf65bd..59052734ddd2 100644 --- a/tools/ci_build/github/apple/c/assemble_c_pod_package.py +++ b/tools/ci_build/github/apple/c/assemble_c_pod_package.py @@ -16,6 +16,7 @@ PackageVariant, copy_repo_relative_to_dir, gen_file_from_template, + get_podspec_values, load_json_config, ) @@ -66,23 +67,25 @@ def assemble_c_pod_package( print("Warning: staging directory already exists", file=sys.stderr) # copy the necessary files to the staging directory - shutil.copytree(framework_dir, staging_dir / framework_dir.name, dirs_exist_ok=True) - shutil.copytree(public_headers_dir, staging_dir / public_headers_dir.name, dirs_exist_ok=True) + shutil.copytree(framework_dir, staging_dir / framework_dir.name, dirs_exist_ok=True, symlinks=True) + shutil.copytree(public_headers_dir, staging_dir / public_headers_dir.name, dirs_exist_ok=True, symlinks=True) copy_repo_relative_to_dir(["LICENSE"], staging_dir) + (ios_deployment_target, macos_deployment_target, weak_framework) = get_podspec_values(framework_info) + # generate the podspec file from the template variable_substitutions = { "DESCRIPTION": pod_config["description"], # By default, we build both "iphoneos" and "iphonesimulator" architectures, and the deployment target should be the same between these two. - "IOS_DEPLOYMENT_TARGET": framework_info["iphonesimulator"]["APPLE_DEPLOYMENT_TARGET"], - "MACOSX_DEPLOYMENT_TARGET": framework_info.get("macosx", {}).get("APPLE_DEPLOYMENT_TARGET", ""), + "IOS_DEPLOYMENT_TARGET": ios_deployment_target, + "MACOSX_DEPLOYMENT_TARGET": macos_deployment_target, "LICENSE_FILE": "LICENSE", "NAME": pod_name, "ORT_C_FRAMEWORK": framework_dir.name, "ORT_C_HEADERS_DIR": public_headers_dir.name, "SUMMARY": pod_config["summary"], "VERSION": pod_version, - "WEAK_FRAMEWORK": framework_info["iphonesimulator"]["WEAK_FRAMEWORK"], + "WEAK_FRAMEWORK": weak_framework, } podspec_template = _script_dir / "c.podspec.template" diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 5609033fc3e3..b546c266c131 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -6,13 +6,16 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Add|| |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Clip|| +|ai.onnx:Concat|| |ai.onnx:Conv|Only 1D/2D Conv is supported.
Bias if provided must be constant.| |ai.onnx:ConvTranspose|Weight and bias must be constant.<br/>padding_type of SAME_UPPER/SAME_LOWER is not supported.<br/>kernel_shape must have default values.<br/>output_shape is not supported.<br/>output_padding must have default values.| +|ai.onnx.DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.| |ai.onnx:Div|| |ai.onnx:Gemm|Input B must be constant.| |ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:GridSample|4D input.<br/>'mode' of 'linear' or 'zeros'.
(mode==linear && padding_mode==reflection && align_corners==0) is not supported.| +|ai.onnx.LeakyRelu|| |ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.| |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Mul|| @@ -21,7 +24,8 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| |ai.onnx.Slice|starts/ends/axes/steps must be constant initializers.| +|ai.onnx:Split|| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| |ai:onnx:Tanh|| -|ai:onnx:Transpose|| +|ai.onnx:Transpose|| diff --git a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py index 1e26482440ea..b7eb34cb0921 100755 --- a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py +++ b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py @@ -17,6 +17,7 @@ copy_repo_relative_to_dir, filter_files, gen_file_from_template, + get_podspec_values, load_json_config, ) @@ -147,12 +148,14 @@ def assemble_objc_pod_package( def path_patterns_as_variable_value(patterns: list[str]): return ", ".join([f'"{pattern}"' for pattern in patterns]) + (ios_deployment_target, macos_deployment_target, _) = get_podspec_values(framework_info) + variable_substitutions = { "C_POD_NAME": c_pod_config["name"], "DESCRIPTION": pod_config["description"], "INCLUDE_DIR_LIST": path_patterns_as_variable_value(include_dirs), - "IOS_DEPLOYMENT_TARGET": framework_info["iphonesimulator"]["APPLE_DEPLOYMENT_TARGET"], - "MACOSX_DEPLOYMENT_TARGET": framework_info.get("macosx", {}).get("APPLE_DEPLOYMENT_TARGET", ""), + "IOS_DEPLOYMENT_TARGET": ios_deployment_target, + "MACOSX_DEPLOYMENT_TARGET": macos_deployment_target, "LICENSE_FILE": license_file, "NAME": pod_name, "PUBLIC_HEADER_FILE_LIST": path_patterns_as_variable_value(pod_files["public_header_files"]), diff --git a/tools/ci_build/github/apple/package_assembly_utils.py b/tools/ci_build/github/apple/package_assembly_utils.py index 8ab8ccdb3f96..c6822466d73d 100644 --- a/tools/ci_build/github/apple/package_assembly_utils.py +++ b/tools/ci_build/github/apple/package_assembly_utils.py @@ -118,6 +118,44 @@ def load_json_config(json_config_file: pathlib.Path): return json.load(config) +def get_podspec_values(framework_info): + """ + Get the podspec deployement targets and weak framework info from the dictionary that load_json_config returned. + Looks for iphonesimulator, iphoneos and macos settings. + Handles missing platforms and checks consistency. + Returns empty string for deployment target if that platofrm is not enabled. 
+ + :return (ios_deployment_target, macos_deployment_target, weak_framework) + """ + ios_deployment_target = "" + macos_deployment_target = "" + weak_framework = "" # should be the same for all platforms + # get info, allowing for a subset of platforms to be specified + for framework in ("iphonesimulator", "iphoneos", "macosx"): + if framework not in framework_info: + continue + + target = framework_info[framework]["APPLE_DEPLOYMENT_TARGET"] + weak = framework_info[framework]["WEAK_FRAMEWORK"] + + if not weak_framework: + weak_framework = weak + else: + # should be consistent + assert weak == weak_framework + + if framework == "macosx": + macos_deployment_target = target + else: + if not ios_deployment_target: + ios_deployment_target = target + else: + # should be consistent + assert ios_deployment_target == target + + return (ios_deployment_target, macos_deployment_target, weak_framework) + + def get_ort_version(): """ Gets the ONNX Runtime version string from the repo. diff --git a/tools/ci_build/github/apple/test_apple_packages.py b/tools/ci_build/github/apple/test_apple_packages.py index 8f06d6dd68fb..14c0b46676ac 100644 --- a/tools/ci_build/github/apple/test_apple_packages.py +++ b/tools/ci_build/github/apple/test_apple_packages.py @@ -89,8 +89,9 @@ def _test_apple_packages(args): # create a zip file contains the framework zip_file_path = local_pods_dir / f"{pod_name}.zip" - # shutil.make_archive require target file as full path without extension - shutil.make_archive(zip_file_path.with_suffix(""), "zip", root_dir=local_pods_dir) + + # shutil.make_archive doesn't preserve symlinks. we know this is running on macOS so use zip + subprocess.run(["zip", "-r", "-y", str(zip_file_path), "."], cwd=local_pods_dir, check=True) # update the podspec to point to the local framework zip file with open(podspec) as file: diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 6649206c0d79..3fba9f54f266 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,11 +31,11 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 jobs: - job: Build_QNN_EP - pool: onnxruntime-qnn-ubuntu-2004-cpu + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: 30 workspace: clean: all @@ -46,6 +46,10 @@ jobs: inputs: versionSpec: $(pythonVersion) + - script: | + env | grep ANDROID + displayName: View Android ENVs + - script: sudo apt-get update -y && sudo apt-get install -y coreutils ninja-build displayName: Install coreutils and ninja @@ -56,13 +60,6 @@ jobs: parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} - - script: | - export ANDROID_SDK_ROOT=/usr/local/lib/android/sdk - export ANDROID_HOME=/usr/local/lib/android/sdk - export ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle - export ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle - displayName: set Android ENVs - - script: | set -e -x rm -rf /tmp/scripts diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 10d9a9a24d88..bcfe4cde9ce5 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ 
b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -174,10 +174,10 @@ stages: - template: templates/clean-agent-build-directory-step.yml -- stage: MASTER_BUILD_STAGE - # The below jobs only run on master build. +- stage: MAIN_BUILD_STAGE + # The below jobs only run on build of main branch. # because coverage report is hard to support in cross machines. - displayName: NNAPI MASTER BUILD&TEST + displayName: NNAPI MAIN BUILD&TEST dependsOn: [] condition: in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI') jobs: @@ -225,29 +225,29 @@ stages: --code_coverage displayName: NNAPI EP, Build, Test, CodeCoverage on Android Emulator + # We need to use llvm-cov from the NDK. - script: | - python3 -m pip install gcovr && \ - python3 tools/ci_build/coverage.py \ - --build_dir build_nnapi \ - --android_sdk_path $ANDROID_HOME + export GCOV="$ANDROID_NDK_HOME/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-cov gcov" + python3 -m pip install gcovr + python3 tools/ci_build/coverage.py --build_dir build_nnapi --android_sdk_path $ANDROID_HOME displayName: Retrieve runtime code coverage files from the emulator and analyze - script: cat '$(Build.SourcesDirectory)/build_nnapi/Debug/coverage_rpt.txt' displayName: Print coverage report + # - task: AzureCLI@2 + # displayName: 'Post Android Code Coverage To DashBoard' + # inputs: + # azureSubscription: AIInfraBuild + # scriptType: bash + # scriptPath: $(Build.SourcesDirectory)/tools/ci_build/github/linux/upload_code_coverage_data.sh + # arguments: '"$(Build.SourcesDirectory)/build_nnapi/Debug/coverage_rpt.txt" "https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=$(Build.BuildId)" arm android nnapi' + # workingDirectory: '$(Build.BinariesDirectory)' + - script: /bin/bash tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh $(pwd) # Build Minimal ORT with NNAPI and reduced Ops, run unit tests on Android Emulator displayName: Build Minimal ORT with NNAPI and run tests - - task: AzureCLI@2 - displayName: 'Post Android Code Coverage To DashBoard' - inputs: - azureSubscription: AIInfraBuild - scriptType: bash - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/github/linux/upload_code_coverage_data.sh - arguments: '"$(Build.SourcesDirectory)/build_nnapi/Debug/coverage_rpt.txt" "https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=$(Build.BuildId)" arm android nnapi' - workingDirectory: '$(Build.BinariesDirectory)' - - template: templates/use-android-emulator.yml parameters: stop: true diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index a66828ee5e18..4a3532dd57fa 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -321,6 +321,7 @@ stages: --build-arg TRT_VERSION=${{ variables.linux_trt_version }} " Repository: onnxruntimeubi8packagestest_torch + UseImageCacheContainerRegistry: false UpdateDepsTxt: false - task: DownloadPackage@1 diff --git a/tools/ci_build/github/azure-pipelines/binary-size-checks-pipeline.yml b/tools/ci_build/github/azure-pipelines/binary-size-checks-pipeline.yml index e9762bc31245..74866cfd59b5 100644 --- a/tools/ci_build/github/azure-pipelines/binary-size-checks-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/binary-size-checks-pipeline.yml @@ -13,21 +13,9 @@ resources: ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - -# checks enabled 
in all builds - - template: templates/android-binary-size-check-stage.yml parameters: Name: MinimalBaseline BuildConfigFile: "tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config" BinarySizeThresholdInBytes: 1306224 DoBuildWithDebugInfo: ${{ parameters.DoBuildWithDebugInfo }} - -# checks excluded from PR builds - -- ${{ if ne(variables['Build.Reason'], 'PullRequest') }}: - - template: templates/android-binary-size-check-stage.yml - parameters: - Name: MinimalWithMobilePackageOps - BuildConfigFile: "tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config" - DoBuildWithDebugInfo: ${{ parameters.DoBuildWithDebugInfo }} diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 2eb7046d80e7..c9210b996b84 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 resources: repositories: @@ -112,17 +112,6 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} -- template: templates/ondevice-training-cpu-packaging-pipeline.yml - parameters: - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime.Training' - AdditionalBuildFlags: '--enable_training_apis' - AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos' - BuildVariant: 'default' - - template: stages/java-cuda-packaging-stage.yml parameters: CudaVersion: 11.8 diff --git a/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml new file mode 100644 index 000000000000..22ee7de8a5de --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/c-api-training-packaging-pipelines.yml @@ -0,0 +1,74 @@ +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +- name: DoCompliance + displayName: Run Compliance Tasks? + type: boolean + default: true + +- name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release + type: boolean + default: true + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. + type: boolean + default: false +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + +# these 2 parameters are used for debugging. 
+- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + +stages: +- template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + +- template: templates/ondevice-training-cpu-packaging-pipeline.yml + parameters: + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + DoCompliance: ${{ parameters.DoCompliance }} + DoEsrp: ${{ parameters.DoEsrp }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime.Training' + AdditionalBuildFlags: '--enable_training_apis' + AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos' + BuildVariant: 'default' + +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Training Package' + ArtifactName: 'drop-signed-nuget-Training-CPU' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 30f56f4b18ae..d3e4a2e00959 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -30,7 +30,7 @@ parameters: - name: CudaVersion displayName: CUDA version type: string - default: '11.8' + default: '12.2' values: - 11.8 - 12.2 diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 78e3b166995e..5c7108861052 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -30,7 +30,7 @@ parameters: - name: CudaVersion displayName: CUDA version type: string - default: '11.8' + default: '12.2' values: - 11.8 - 12.2 diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index 7cfff805c3b3..4ab1b4996a1d 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,14 +8,12 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 10.0.cuda_11_8_cudnn_8 + default: 10.2.cuda_12_5_cudnn_9 values: - - 8.4.cuda_11_6_cudnn_8 - - 8.5.cuda_11_8_cudnn_8 - 8.6.cuda_11_8_cudnn_8 - 8.6.cuda_12_3_cudnn_9 - - 10.0.cuda_11_8_cudnn_8 - - 10.0.cuda_12_4_cudnn_9 + - 10.2.cuda_11_8_cudnn_8 + - 10.2.cuda_12_5_cudnn_9 - BIN - name: UseTensorrtOssParser diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 0d67b0947be5..9282792a6b41 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,7 +32,7 
@@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml b/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml index 4bfd726f5c58..aeb250e1e0cb 100644 --- a/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/nuget-cuda-publishing-pipeline.yml @@ -6,6 +6,7 @@ resources: branches: include: - main + - rel-* branch: main parameters: @@ -16,15 +17,15 @@ parameters: variables: - name: ArtifactFeed ${{ if eq(parameters.isReleaseBuild, false) }}: - value: ort-cuda-12-nightly + value: ORT-Nightly ${{ else }}: value: onnxruntime-cuda-12 stages: -- template: stages/nuget-cuda-publishing-stage.yml - parameters: - artifact_feed: $(ArtifactFeed) + - template: stages/nuget-cuda-publishing-stage.yml + parameters: + artifact_feed: $(ArtifactFeed) -- template: stages/java-cuda-publishing-stage.yml - parameters: - artifact_feed: $(ArtifactFeed) + - template: stages/java-cuda-publishing-stage.yml + parameters: + artifact_feed: $(ArtifactFeed) \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml index 5fa80bf7ff6d..1fa88318b8c0 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml @@ -22,3 +22,5 @@ stages: enable_windows_gpu: false enable_mac_cpu: true enable_linux_arm: false + enable_windows_arm64_qnn: false + enable_windows_x64_qnn: false diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index e0c588413415..b78d586288ba 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -9,10 +9,22 @@ resources: - rel-* branch: main +parameters: + - name: isReleaseBuild + type: boolean + default: false + +variables: + - name: ArtifactFeed + ${{ if eq(parameters.isReleaseBuild, false) }}: + value: ort-cuda-11-nightly + ${{ else }}: + value: onnxruntime-cuda-11 + stages: - template: templates/publish-nuget-steps.yml parameters: - stage_name: 'Publish_NuGet_Packag_And_Report' + stage_name: 'Publish_NuGet_Package_And_Report' include_cpu_ep: true download_artifacts_steps: - download: build @@ -20,12 +32,11 @@ stages: artifact: 'drop-signed-nuget-dml' - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-dml\*" $(Build.BinariesDirectory)\nuget-artifact\final-package - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-Training-CPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-Training-CPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + # Publish CUDA 11 Nuget/Java pkgs to ADO feed + - template: stages/nuget-cuda-publishing-stage.yml + parameters: + artifact_feed: $(ArtifactFeed) - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-GPU' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-GPU\*" $(Build.BinariesDirectory)\nuget-artifact\final-package + - template: stages/java-cuda-publishing-stage.yml + parameters: + artifact_feed: $(ArtifactFeed) diff --git 
a/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml index 50e0ca3708d2..1217163c0713 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml @@ -16,7 +16,7 @@ parameters: variables: - name: ArtifactFeed ${{ if eq(parameters.isReleaseBuild, false) }}: - value: ort-cuda-12-nightly + value: ORT-Nightly ${{ else }}: value: onnxruntime-cuda-12 diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index cd3966633d74..c7a1b595a6c6 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.24.0.240626 + default: 2.25.0.240728 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 7229bc5dbd11..25d50f4255cb 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-publishing-stage.yml index 70d92286b396..946d651b795d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-publishing-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-publishing-stage.yml @@ -8,7 +8,7 @@ stages: jobs: - job: JAR_Publishing_GPU #TD-DO: figure out a way to package nightly jar. 
Currently Java version are set from VERSION_NUMBER file - condition: ${{ eq(parameters.artifact_feed, 'onnxruntime-cuda-12') }} + condition: ${{ or(eq(parameters.artifact_feed, 'onnxruntime-cuda-11'), eq(parameters.artifact_feed, 'onnxruntime-cuda-12')) }} workspace: clean: all pool: 'onnxruntime-Win-CPU-2022' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 7ba1179e7ad4..0368c91290d5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -111,7 +111,7 @@ stages: cp -R $(Build.BinariesDirectory)/ios_framework/framework_out/onnxruntime.xcframework \ $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) pushd $(Build.BinariesDirectory)/artifacts_staging - zip -vr $(Build.BinariesDirectory)/artifacts/onnxruntime_xcframework.zip \ + zip -vry $(Build.BinariesDirectory)/artifacts/onnxruntime_xcframework.zip \ onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) popd displayName: "Build Apple xcframework" @@ -364,6 +364,8 @@ stages: workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' displayName: 'List artifacts' + - template: set-version-number-variables-step.yml + # Reconstruct the build dir - task: PowerShell@2 displayName: 'Extract native libraries for addition to nuget native package' @@ -403,7 +405,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:PackageVersion=$(OnnxRuntimeVersion)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - ${{ if eq(parameters.DoEsrp, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index e2b71c5c55fd..0f4328f75e1b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -51,15 +51,15 @@ jobs: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{parameters.BaseImage}}" - Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} - + Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging + - ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: - template: get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{parameters.BaseImage}}" - Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} + Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging UpdateDepsTxt: false - task: CmdLine@2 @@ -67,7 +67,7 @@ jobs: script: | mkdir -p $HOME/.onnx docker run --rm 
--volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3.9 \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging /bin/bash -c "python3.9 \ /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release \ --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/installed" workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index bf11730c2ce2..45c79a677b68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.167 + version: 1.0.175 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.167 + version: 1.0.175 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/get-docker-image-steps.yml b/tools/ci_build/github/azure-pipelines/templates/get-docker-image-steps.yml index 94cdf042ec62..5b6769685a97 100644 --- a/tools/ci_build/github/azure-pipelines/templates/get-docker-image-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/get-docker-image-steps.yml @@ -53,6 +53,7 @@ steps: displayName: patch manylinux - script: | + docker version docker image ls docker system df displayName: Check Docker Images @@ -71,52 +72,25 @@ steps: displayName: "Get ${{ parameters.Repository }} image for ${{ parameters.Dockerfile }}" ContainerRegistry: onnxruntimebuildcache - ${{ if eq(parameters.UseImageCacheContainerRegistry, false) }}: - - task: Cache@2 - displayName: Cache Docker Image Task - inputs: - key: ' "${{ parameters.Repository }}" | "$(Build.SourceVersion)" ' - path: ${{ parameters.IMAGE_CACHE_DIR }} - restoreKeys: | - "${{ parameters.Repository }}" | "$(Build.SourceVersion)" - "${{ parameters.Repository }}" - cacheHitVar: CACHE_RESTORED - condition: eq('${{ parameters.UsePipelineCache }}', 'true') - - - script: | - test -f ${{ parameters.IMAGE_CACHE_DIR }}/cache.tar && docker load -i ${{ parameters.IMAGE_CACHE_DIR }}/cache.tar - docker image ls - displayName: Docker restore - condition: eq('${{ parameters.UsePipelineCache }}', 'true') - - - script: | - if [ ${{ parameters.UsePipelineCache}} ] - then - use_imagecache="--use_imagecache" - else - use_imagecache="" - fi - ${{ parameters.ScriptName }} \ - --dockerfile "${{ parameters.Dockerfile }}" \ - --context "${{ parameters.Context }}" \ - --docker-build-args "${{ parameters.DockerBuildArgs }}" \ - --repository "${{ parameters.Repository }}" \ - $use_imagecache - displayName: "Get ${{ parameters.Repository }} image for ${{ parameters.Dockerfile }}" - - - script: | - set -ex - mkdir -p "${{ parameters.IMAGE_CACHE_DIR }}" - 
docker save -o "${{ parameters.IMAGE_CACHE_DIR }}/cache.tar" ${{ parameters.Repository }} - docker image ls - docker system df - displayName: Docker save - condition: eq('${{ parameters.UsePipelineCache }}', 'true') + # the difference is no --container-registry + - template: with-container-registry-steps.yml + parameters: + Steps: + - script: | + ${{ parameters.ScriptName }} \ + --dockerfile "${{ parameters.Dockerfile }}" \ + --context "${{ parameters.Context }}" \ + --docker-build-args "${{ parameters.DockerBuildArgs }}" \ + --repository "${{ parameters.Repository }}" + displayName: "Get ${{ parameters.Repository }} image for ${{ parameters.Dockerfile }}" + ContainerRegistry: onnxruntimebuildcache - - script: | - echo ${{ parameters.IMAGE_CACHE_DIR }} - ls -lah ${{ parameters.IMAGE_CACHE_DIR }} - displayName: Display docker dir - condition: eq('${{ parameters.UsePipelineCache }}', 'true') +- script: | + docker version + docker image ls + docker system df + df -h + displayName: Check Docker Images - ${{ if and(eq(parameters.UpdateDepsTxt, true), or(eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29'),eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c'))) }}: - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 734ad43e0066..e727ec4f7ef5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.24.0.240626' + default: '2.25.0.240728' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index de29a3de9fde..6459888a40ae 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -7,7 +7,7 @@ parameters: default: false - name: CudaVersion type: string - default: '11.8' + default: '12.2' values: - 11.8 - 12.2 diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 900adc969025..912cac6fbb99 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.24.0.240626' + default: '2.25.0.240728' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index 63d521f1e7d9..fba463b49016 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -9,10 +9,10 @@ parameters: default: false - name: PrimaryCUDAVersion type: string - default: '11.8' + default: '12.2' - name: SecondaryCUDAVersion type: string - default: '12.2' + default: '11.8' steps: - ${{ if eq(parameters.DownloadCUDA, 'true') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml 
b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index fb9ff65fe853..022f85cc0a46 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -317,3 +317,4 @@ stages: ArtifactSuffix: 'Training-CPU' StageSuffix: 'Training_CPU' NativePackagePrefix: 'onnxruntime-training' + CustomOpArtifactName: 'onnxruntime-training-linux-x64' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 447e35244eb6..faf453140052 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -63,7 +63,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.24.0.240626 + default: 2.25.0.240728 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 40e8583141df..c3a2b7be7ebd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 - name: PYTHON_VERSION type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 33335bb2be2d..5cf03a7cdd10 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 944745b69ca6..c7fd26712329 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.24.0.240626' + QnnSdk: '2.25.0.240728' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml deleted file mode 100644 index 438e51175c5b..000000000000 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ /dev/null @@ -1,118 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -parameters: -- name: RunOnnxRuntimeTests - displayName: Run Tests? 
- type: boolean - default: true - -stages: -- stage: cuda - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - additionalBuildFlags: >- - --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" - --enable_cuda_profiling --enable_transformers_tool_test - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON - --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: CUDA - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 - -- stage: training - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - additionalBuildFlags: >- - --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" - --skip_onnx_tests - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: CUDA - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 - isTraining: true - -- stage: dml - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat - buildArch: x64 - additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - ORT_EP_NAME: DML - WITH_CACHE: false - MachinePool: onnxruntime-Win2022-GPU-dml-A10 - -- stage: kernelDocumentation - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat - buildArch: x64 - # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags - additionalBuildFlags: >- - --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda - --cuda_home="$(Agent.TempDirectory)\v11.8" - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: false - GenerateDocumentation: true - ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. 
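The stages above disappear along with the monolithic win-gpu-ci-pipeline.yml; the per-scenario pipelines added below recreate them with the CUDA location parameterized rather than hard-coded to v11.8. The mechanism is ordinary Azure Pipelines template expansion: `${{ parameters.CudaVersion }}` is resolved when the run is compiled, so the job receives an already-substituted `--cuda_home` flag. A minimal, illustrative sketch of that pattern (not the repository's actual job template, which is templates/jobs/win-ci-vs-2022-job.yml as used in the new files):

```yaml
parameters:
- name: CudaVersion
  type: string
  default: '12.2'
  values:
  - 11.8
  - 12.2

stages:
- stage: cuda
  dependsOn: []
  jobs:
  - job: build
    steps:
    # ${{ }} is expanded at compile time, so with the default this line already
    # reads --cuda_home="$(Agent.TempDirectory)\v12.2" before the job starts;
    # $(Agent.TempDirectory) itself is still resolved at runtime.
    - script: echo --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}"
      displayName: Show the resolved CUDA flag
```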
- WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml new file mode 100644 index 000000000000..78e1624b5d12 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml @@ -0,0 +1,64 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +stages: +- stage: cuda + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + additionalBuildFlags: >- + --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --enable_cuda_profiling --enable_transformers_tool_test + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON + --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: CUDA + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml new file mode 100644 index 000000000000..904979f39ca3 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml @@ -0,0 +1,52 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? 
+ type: boolean + default: true + +stages: +- stage: dml + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env.bat + buildArch: x64 + additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: DML + WITH_CACHE: false + MachinePool: onnxruntime-Win2022-GPU-dml-A10 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml new file mode 100644 index 000000000000..410688933135 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml @@ -0,0 +1,61 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +stages: +- stage: kernelDocumentation + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags + additionalBuildFlags: >- + --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda + --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: false + GenerateDocumentation: true + ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. 
+ WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index 70c0c7d4a04e..8c9ecdfb9019 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -26,6 +26,21 @@ pr: - 'js/web' - 'onnxruntime/core/providers/js' #### end trigger #### +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +variables: + - name: win_trt_folder + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5 jobs: - job: 'build' @@ -55,7 +70,7 @@ jobs: WithCache: True Today: $(TODAY) AdditionalKey: "gpu-tensorrt | RelWithDebInfo" - BuildPyArguments: '--config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86' + BuildPyArguments: '--config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}" --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86' MsbuildArguments: $(MsbuildArguments) BuildArch: 'x64' Platform: 'x64' @@ -75,7 +90,7 @@ jobs: del wheel_filename_file python.exe -m pip install -q --upgrade %WHEEL_FILENAME% set PATH=$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo;%PATH% - python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}" --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' displayName: 'Run tests' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml new file mode 100644 index 000000000000..3bb6c267f001 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-training-ci-pipeline.yml @@ -0,0 +1,63 @@ 
+##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +stages: +- stage: training + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + additionalBuildFlags: >- + --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --skip_onnx_tests + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: CUDA + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 + isTraining: true diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index e1b8b718e992..31cdbeb99be4 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 97c4ab15095c..54277bcb4039 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.24.0.240626 + default: 2.25.0.240728 jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index d96b34297427..07885ba65af8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -2,7 +2,7 @@ # Please overwrite BASEIMAGE, TRT_VERSION and other arguments with # --docker-build-args ' --build-arg BASEIMAGE=other_base_image --build-arg TRT_VERSION=other_trt_version etc...' # for other cuda version and TRT version -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 FROM $BASEIMAGE ARG TRT_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index 2d3dc05285e3..b587a7df554b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -2,11 +2,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
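Both the manylinux build image above and the UBI8 packaging/test image whose changes follow keep their CUDA and TensorRT versions overridable through --docker-build-args, as the Dockerfile headers note, so moving the defaults to CUDA 12 does not strand jobs that still need CUDA 11.8. A hypothetical invocation that pins the previous combination (the parameter names match get-docker-image-steps.yml as used elsewhere in this diff; the Context path and Repository name are illustrative, and the BASEIMAGE/TRT_VERSION values are the old defaults shown above):

```yaml
- template: templates/get-docker-image-steps.yml
  parameters:
    Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
    Context: tools/ci_build/github/linux/docker
    DockerBuildArgs: >-
      --build-arg BUILD_UID=$( id -u )
      --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8
      --build-arg TRT_VERSION=10.2.0.19-1.cuda11.8
    Repository: onnxruntimecuda11build
```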
# -------------------------------------------------------------- -# Dockerfile to Test ONNX Runtime on UBI8 with TensorRT 10.0 and CUDA 11.8 by default +# Dockerfile to Test ONNX Runtime on UBI8 with TensorRT 10 and CUDA 12 by default # Build base image with required system packages -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG TRT_VERSION=10.2.0.19-1.cuda11.8 +ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 +ARG TRT_VERSION=10.2.0.19-1.cuda12.4 FROM $BASEIMAGE AS base ARG TRT_VERSION ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index 2cd054e6246b..ca00050121d6 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -5,7 +5,7 @@ ARG BASEIMAGE=arm64v8/almalinux:8 FROM $BASEIMAGE -ENV PATH /opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH=/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index caf9583807b6..ef28dde67617 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -5,7 +5,7 @@ ARG BASEIMAGE=amd64/almalinux:8 FROM $BASEIMAGE -ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH=/usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 diff --git a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config deleted file mode 100644 index dbebec5788dd..000000000000 --- a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config +++ /dev/null @@ -1,19 +0,0 @@ -{ - "type": "minimal-with-mobile-package-ops", - "os": "android", - "arch": "arm64-v8a", - "build_params": [ - "--enable_lto", - "--android", - "--android_sdk_path=/android_home", - "--android_ndk_path=/ndk_home", - "--android_abi=arm64-v8a", - "--android_api=29", - "--minimal_build", - "--build_shared_lib", - "--build_java", - "--disable_ml_ops", - "--disable_exceptions", - "--include_ops_by_config=/onnxruntime_src/tools/ci_build/github/android/mobile_package.required_operators.config" - ] -} diff --git a/tools/ci_build/github/windows/eager/requirements.txt b/tools/ci_build/github/windows/eager/requirements.txt index 08e7baa76471..b285defd89f5 100644 --- a/tools/ci_build/github/windows/eager/requirements.txt +++ b/tools/ci_build/github/windows/eager/requirements.txt @@ -3,5 +3,5 @@ wheel numpy==1.21.6 ; python_version < '3.9' numpy==2.0.0 ; python_version >= '3.9' typing_extensions -torch==1.13.1 +torch==2.2.0 parameterized diff --git 
a/tools/ci_build/github/windows/install_third_party_deps.ps1 b/tools/ci_build/github/windows/install_third_party_deps.ps1 index 07679006fb34..168df9018879 100644 --- a/tools/ci_build/github/windows/install_third_party_deps.ps1 +++ b/tools/ci_build/github/windows/install_third_party_deps.ps1 @@ -27,7 +27,7 @@ $Env:CMAKE_PREFIX_PATH = "$install_prefix" New-Item -Path "$install_prefix" -ItemType Directory -Force # Setup compile flags -$compile_flags = @('/MP', '/guard:cf', '/DWIN32', '/D_WINDOWS', '/DWINVER=0x0A00', '/D_WIN32_WINNT=0x0A00', '/DNTDDI_VERSION=0x0A000000', '/W3') +$compile_flags = @('/MP', '/guard:cf', '/DWIN32', '/D_WINDOWS', '/D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR', '/DWINVER=0x0A00', '/D_WIN32_WINNT=0x0A00', '/DNTDDI_VERSION=0x0A000000', '/W3') $linker_flags=@('/guard:cf') if ($use_cache) { diff --git a/tools/ci_build/github/windows/setup_env_cuda.bat b/tools/ci_build/github/windows/setup_env_cuda.bat index 2233f7611ab6..f93938e2a900 100644 --- a/tools/ci_build/github/windows/setup_env_cuda.bat +++ b/tools/ci_build/github/windows/setup_env_cuda.bat @@ -1,17 +1,17 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( -set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( +set PATH=%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64;%PATH% ) else ( - set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% + set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64;%PATH% ) -@REM The default version is still cuda v11.8, because set cuda v12.2 after it -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( - set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +@REM The default version is still cuda v12.2, because set cuda v11.8 after it +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 ) else ( - set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64 ) set GRADLE_OPTS=-Dorg.gradle.daemon=false diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index 6c59866ea925..35e4f7e30243 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -1,17 +1,17 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. 
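The reordering in setup_env_cuda.bat above is what flips the default toolkit: the batch file now prepends the CUDA 12.2 directories and only appends CUDA 11.8, and Windows resolves executables and DLLs from the first matching PATH entry. A small diagnostic sketch, assuming it runs inside the CI job after the batch file, that reports which toolkit actually wins:

import os
import shutil

# The nvcc that PATH resolution picks is the "default" CUDA toolkit.
print("default nvcc:", shutil.which("nvcc"))

# List CUDA-related PATH entries in precedence order.
cuda_dirs = [p for p in os.environ.get("PATH", "").split(os.pathsep)
             if "CUDA" in p or "\\v11.8" in p or "\\v12.2" in p]
for rank, entry in enumerate(cuda_dirs):
    print(f"{rank}: {entry}")

The same first-match rule is why the comment says the default stays at v12.2 even though v11.8 is still appended afterwards.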
-if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( - set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( + set PATH=%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64;%PATH% ) else ( - set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% + set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64;%PATH% ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5\lib;%PATH% -@REM The default version is still cuda v11.8, because set cuda v12.2 after it -set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5\lib -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( - set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +@REM The default version is still cuda v12.2, because set cuda v11.8 after it +set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8\lib +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 ) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\extras\CUPTI\lib64 ) diff --git a/tools/ci_build/github/windows/setup_env_trt.bat b/tools/ci_build/github/windows/setup_env_trt.bat index 249bb9881589..7ec7558edab3 100644 --- a/tools/ci_build/github/windows/setup_env_trt.bat +++ b/tools/ci_build/github/windows/setup_env_trt.bat @@ -1,11 +1,11 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. 
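Once setup_env_gpu.bat has put the CUDA 12.2 and TensorRT 10.2 (cuda-12.5 build) directories first on PATH, one quick sanity check is to ask ONNX Runtime which execution providers it exposes. This assumes a GPU-enabled onnxruntime wheel built with both the CUDA and TensorRT EPs is installed; it is a local check, not part of the pipeline:

import onnxruntime as ort

# get_available_providers() lists the EPs this wheel was built with; the
# CUDA/TensorRT DLLs from the directories the batch file put on PATH are only
# loaded when an InferenceSession is created with those providers.
print(ort.get_available_providers())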
-if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( - set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 ) else ( - set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64 + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-11.8\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.2.0.19.Windows10.x86_64.cuda-12.5\lib;%PATH% set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY \ No newline at end of file diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index d26fec41033c..0d90061e9c68 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -34,7 +34,10 @@ "orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml", "orttraining-mac-ci-pipeline.yml", "win-ci-pipeline.yml", - "win-gpu-ci-pipeline.yml", + "win-gpu-ci-dml-pipeline.yml", + "win-gpu-ci-cuda-pipeline.yml", + "win-gpu-ci-training-pipeline.yml", + "win-gpu-ci-doc-gen-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-qnn-arm64-ci-pipeline.yml", "win-qnn-ci-pipeline.yml", diff --git a/tools/python/gen_ort_mobile_pkg_doc.py b/tools/python/gen_ort_mobile_pkg_doc.py deleted file mode 100644 index 482cb05bb50b..000000000000 --- a/tools/python/gen_ort_mobile_pkg_doc.py +++ /dev/null @@ -1,97 +0,0 @@ -import argparse -import os -import pathlib - -from util import reduced_build_config_parser -from util.ort_format_model.operator_type_usage_processors import GloballyAllowedTypesOpTypeImplFilter - - -def generate_docs(output_file, required_ops, op_type_impl_filter): - with open(output_file, "w") as out: - out.write("# ONNX Runtime Mobile Pre-Built Package Operator and Type Support\n\n") - - # Description - out.write("## Supported operators and types\n\n") - out.write( - "The supported operators and types are based on what is required to support float32 and quantized " - "versions of popular models. 
The full list of input models used to determine this list is available " - "[here](https://github.com/microsoft/onnxruntime/blob/main/tools/ci_build/github/android/mobile_package" - ".required_operators.readme.txt)" - ) - out.write("\n\n") - - # Globally supported types - out.write("## Supported data input types\n\n") - assert op_type_impl_filter.__class__ is GloballyAllowedTypesOpTypeImplFilter - global_types = op_type_impl_filter.global_type_list() - for type in sorted(global_types): - out.write(f" - {type}\n") - out.write("\n") - out.write("NOTE: Operators used to manipulate dimensions and indices will support int32 and int64.\n\n") - - domain_op_opsets = [] - for domain in sorted(required_ops.keys()): - op_opsets = {} - domain_op_opsets.append((domain, op_opsets)) - for opset in sorted(required_ops[domain].keys()): - str_opset = str(opset) - for op in required_ops[domain][opset]: - op_with_domain = f"{domain}:{op}" - if op_with_domain not in op_opsets: - op_opsets[op_with_domain] = [] - - op_opsets[op_with_domain].append(str_opset) - - out.write("## Supported Operators\n\n") - out.write("|Operator|Opsets|\n") - out.write("|--------|------|\n") - for domain, op_opsets in domain_op_opsets: - out.write(f"|**{domain}**||\n") - for op in sorted(op_opsets.keys()): - out.write("|{}|{}|\n".format(op, ", ".join(op_opsets[op]))) - out.write("|||\n") - - -def main(): - script_dir = os.path.dirname(os.path.realpath(__file__)) - - parser = argparse.ArgumentParser( - description="ONNX Runtime Mobile Pre-Built Package Operator and Type Support Documentation Generator", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - default_config_path = pathlib.Path( - os.path.join(script_dir, "../ci_build/github/android/mobile_package.required_operators.config") - ).resolve() - - default_output_path = pathlib.Path( - os.path.join(script_dir, "../../docs/ORTMobilePackageOperatorTypeSupport.md") - ).resolve() - - parser.add_argument( - "--config_path", - help="Path to build configuration used to generate package.", - required=False, - type=pathlib.Path, - default=default_config_path, - ) - - parser.add_argument( - "--output_path", - help="output markdown file path", - required=False, - type=pathlib.Path, - default=default_output_path, - ) - - args = parser.parse_args() - config_file = args.config_path.resolve(strict=True) # must exist so strict=True - output_path = args.output_path.resolve() - - enable_type_reduction = True - required_ops, op_type_impl_filter = reduced_build_config_parser.parse_config(config_file, enable_type_reduction) - generate_docs(output_path, required_ops, op_type_impl_filter) - - -if __name__ == "__main__": - main() diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py deleted file mode 100644 index 23bfce2e1c64..000000000000 --- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -# Helper script that will check if the types and operators used in an ONNX model -# are supported by the pre-built ORT Mobile package. 
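For context on what the deleted gen_ort_mobile_pkg_doc.py produced, here is a toy re-run of its grouping step on made-up operator data (the real data came from reduced_build_config_parser.parse_config on the mobile package config); it shows the |Operator|Opsets| table layout the script wrote:

# Made-up required_ops data, shaped like parse_config's output:
# {domain: {opset: set_of_op_types}}.
required_ops = {"ai.onnx": {12: {"Add", "Conv"}, 13: {"Add", "Relu"}}}

rows: dict[str, list[str]] = {}
for domain in sorted(required_ops):
    for opset in sorted(required_ops[domain]):
        for op in required_ops[domain][opset]:
            rows.setdefault(f"{domain}:{op}", []).append(str(opset))

print("|Operator|Opsets|")
print("|--------|------|")
for op in sorted(rows):
    print(f"|{op}|{', '.join(rows[op])}|")

This prints rows such as |ai.onnx:Add|12, 13|, matching the markdown table the generator emitted.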
- -import argparse -import logging -import pathlib -import sys - -import onnx - -from ..onnx_model_utils import ModelProtoWithShapeInfo, get_opsets_imported -from ..reduced_build_config_parser import parse_config - -cpp_to_tensorproto_type = { - "float": 1, - "uint8_t": 2, - "int8_t": 3, - "uint16_t": 4, - "int16_t": 5, - "int32_t": 6, - "int64_t": 7, - "std::string": 8, - "bool": 9, - "MLFloat16": 10, - "double": 11, - "uint32_t": 12, - "uint64_t": 13, - "Complex64": 14, # not supported by ORT - "Complex128": 15, # not supported by ORT - "BFloat16": 16, -} - -tensorproto_type_to_cpp = {v: k for k, v in cpp_to_tensorproto_type.items()} - - -def check_graph(graph, opsets, required_ops, global_types, special_types, unsupported_ops, logger): - """ - Check the graph and any subgraphs for usage of types or operators which we know are not supported. - :param graph: Graph to process. - :param opsets: Map of domain to opset version that the model imports. - :param required_ops: Operators that are included in the pre-built package. - :param global_types: Types globally enabled in the pre-built package. - :param special_types: Types that are always enabled for a subset of operators and are _usually_ supported but are - not guaranteed to be. We would need to add a lot of infrastructure to know for sure so - currently we treat them as supported. - :param unsupported_ops: Set of unsupported operators that were found. - :param logger: Logger for diagnostic output. - :return: Returns whether the graph uses unsupported operators or types. - """ - has_unsupported_types = False - value_info_map = {vi.name: vi for vi in graph.value_info} - - def _is_type_supported(value_info, description): - is_supported = True - type_name = value_info.type.WhichOneof("value") - if type_name == "tensor_type": - t = value_info.type.tensor_type.elem_type - if t not in global_types and t not in special_types: - cpp_type = tensorproto_type_to_cpp[t] - logger.debug(f"Element type {cpp_type} of {description} is not supported.") - is_supported = False - else: - # we don't support sequences, map, sparse tensors, or optional types in the pre-built package - logger.debug(f"Data type {type_name} of {description} is not supported.") - is_supported = False - - return is_supported - - def _input_output_is_supported(value_info, input_output): - return _is_type_supported(value_info, f"graph {input_output} {value_info.name}") - - # node outputs are simpler to check. - # node inputs have a much wider mix of types, some of which come from initializers and most likely are always - # enabled as we generally do type reduction on the user data input to the operator and not the weights/etc. which - # come from initializers. - def _node_output_is_supported(name): - is_supported = True - if name in value_info_map: - vi = value_info_map[name] - is_supported = _is_type_supported(vi, f"node output {name}") - else: - # we don't have type info so ignore - pass - - return is_supported - - for i in graph.input: - if not _input_output_is_supported(i, "input"): - has_unsupported_types = True - - for o in graph.output: - if not _input_output_is_supported(o, "output"): - has_unsupported_types = True - - for node in graph.node: - # required_ops are map of [domain][opset] to set of op_type names. '' == ai.onnx - domain = node.domain or "ai.onnx" - - # special case Constant as we will convert to an initializer during model load - if domain == "ai.onnx" and node.op_type == "Constant": - continue - - # some models don't have complete imports. 
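The cpp_to_tensorproto_type table above hard-codes the integer codes of onnx.TensorProto.DataType. A quick spot-check against the onnx package (assuming onnx is installed) confirms a few of them and shows how to list the full enum if the map ever needs new entries:

import onnx

# Spot-check the hard-coded codes against the official enum values.
assert onnx.TensorProto.FLOAT == 1
assert onnx.TensorProto.INT64 == 7
assert onnx.TensorProto.BFLOAT16 == 16

# Full name -> value mapping, useful when extending the table.
print(dict(onnx.TensorProto.DataType.items()))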
use 1 as a default as that's valid for custom domains and should - # result in an error for any others. not sure why ONNX or ORT validation allows this though. - opset = opsets.get(domain, 1) - if ( - domain not in required_ops - or opset not in required_ops[domain] - or node.op_type not in required_ops[domain][opset] - ): - unsupported_ops.add(f"{domain}:{opset}:{node.op_type}") - - for output_name in node.output: - if not _node_output_is_supported(output_name): - has_unsupported_types = True - - # recurse into subgraph for control flow nodes (Scan/Loop/If) - for attr in node.attribute: - if attr.HasField("g"): - check_graph(attr.g, opsets, required_ops, global_types, special_types, unsupported_ops, logger) - - return has_unsupported_types or unsupported_ops - - -def _get_global_tensorproto_types(op_type_impl_filter, logger: logging.Logger): - """ - Map the globally supported types (C++) to onnx.TensorProto.DataType values used in the model - See https://github.com/onnx/onnx/blob/1faae95520649c93ae8d0b403816938a190f4fa7/onnx/onnx.proto#L485 - - Additionally return a set of types we special case as being able to generally be considered as supported. - :param op_type_impl_filter: type filter from reduced build configuration parser - :param logger: Logger - :return: tuple of globally enabled types and special cased types - """ - global_cpp_types = op_type_impl_filter.global_type_list() - global_onnx_tensorproto_types = set() - - for t in global_cpp_types: - if t in cpp_to_tensorproto_type: - global_onnx_tensorproto_types.add(cpp_to_tensorproto_type[t]) - else: - logger.error(f"Error: Unexpected data type of {t} in package build config's globally enabled types.") - sys.exit(-1) - - # a subset of operators require int32 and int64 to always be enabled, as those types are used for dimensions in - # shapes and indices. - # additionally we have a number of operators (e.g. Not, Where) that always require the use of bool. - # this _may_ mean values involving these types can be processed, but without adding a lot more code we don't know - # for sure. - special_types = [ - cpp_to_tensorproto_type["int32_t"], - cpp_to_tensorproto_type["int64_t"], - cpp_to_tensorproto_type["bool"], - ] - - return global_onnx_tensorproto_types, special_types - - -def get_default_config_path(): - # get default path to config that was used to create the pre-built package. - script_dir = pathlib.Path(__file__).parent - local_config = script_dir / "mobile_package.required_operators.config" - - # if we're running in the ORT python package the file should be local. otherwise assume we're running from the - # ORT repo - if local_config.exists(): - default_config_path = local_config - else: - ort_root = script_dir.parents[3] - default_config_path = ( - ort_root / "tools" / "ci_build" / "github" / "android" / "mobile_package.required_operators.config" - ) - - return default_config_path - - -def run_check_with_model( - model_with_type_info: onnx.ModelProto, mobile_pkg_build_config: pathlib.Path, logger: logging.Logger -): - """ - Check if an ONNX model can be used with the ORT Mobile pre-built package. - :param model_with_type_info: ONNX model that has had ONNX shape inferencing run on to add type/shape information. - :param mobile_pkg_build_config: Configuration file used to build the ORT Mobile package. 
- :param logger: Logger for output - :return: True if supported - """ - if not mobile_pkg_build_config: - mobile_pkg_build_config = get_default_config_path() - - enable_type_reduction = True - config_path = str(mobile_pkg_build_config.resolve(strict=True)) - required_ops, op_type_impl_filter = parse_config(config_path, enable_type_reduction) - global_onnx_tensorproto_types, special_types = _get_global_tensorproto_types(op_type_impl_filter, logger) - - # get the opset imports - opsets = get_opsets_imported(model_with_type_info) - - # If the ONNX opset of the model is not supported we can recommend using our tools to update that first. - supported_onnx_opsets = set(required_ops["ai.onnx"].keys()) - # we have a contrib op that is erroneously in the ai.onnx domain with opset 1. manually remove that incorrect value - supported_onnx_opsets.remove(1) - onnx_opset_model_uses = opsets["ai.onnx"] - if onnx_opset_model_uses not in supported_onnx_opsets: - logger.info(f"Model uses ONNX opset {onnx_opset_model_uses}.") - logger.info(f"The pre-built package only supports ONNX opsets {sorted(supported_onnx_opsets)}.") - logger.info( - "Please try updating the ONNX model opset to a supported version using " - "python -m onnxruntime.tools.onnx_model_utils.update_onnx_opset ..." - ) - - return False - - unsupported_ops = set() - logger.debug( - "Checking if the data types and operators used in the model are supported in the pre-built ORT package..." - ) - unsupported = check_graph( - model_with_type_info.graph, - opsets, - required_ops, - global_onnx_tensorproto_types, - special_types, - unsupported_ops, - logger, - ) - - if unsupported_ops: - logger.info("Unsupported operators:") - for entry in sorted(unsupported_ops): - logger.info(" " + entry) # noqa: G003 - - if unsupported: - logger.info("\nModel is not supported by the pre-built package due to unsupported types and/or operators.") - logger.info( - "Please see https://onnxruntime.ai/docs/install/#install-on-web-and-mobile for information " - "on what is supported in the pre-built package." - ) - logger.info( - "The 'full' ORT package for Android (onnxruntime-android) or iOS (onnxruntime-{objc|c}) could be used, " - "or a custom build of ONNX Runtime will be required if binary size is critical. Please see " - "https://onnxruntime.ai/docs/build/custom.html for details on performing that." - ) - else: - logger.info("Model should work with the pre-built package.") - - logger.info("---------------\n") - - return not unsupported - - -def run_check(model_path: pathlib.Path, mobile_pkg_build_config: pathlib.Path, logger: logging.Logger): - """ - Check if an ONNX model will be able to be used with the ORT Mobile pre-built package. - :param model_path: Path to ONNX model. - :param mobile_pkg_build_config: Configuration file used to build the ORT Mobile package. - :param logger: Logger for output - :return: True if supported - """ - logger.info( - f"Checking if pre-built ORT Mobile package can be used with {model_path} once model is " - "converted from ONNX to ORT format using onnxruntime.tools.convert_onnx_models_to_ort..." - ) - - model_file = model_path.resolve(strict=True) - - # we need to run shape inferencing to populate that type info for node outputs. - # we will get warnings if the model uses ORT contrib ops (ONNX does not have shape inferencing for those), - # and shape inferencing will be lost downstream of those. 
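As the comment notes, the checker relies on ONNX shape inferencing to attach type information to node outputs before check_graph can inspect them. A minimal sketch of that step on its own (the model path is illustrative):

import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # illustrative path
inferred = shape_inference.infer_shapes(model)
print(f"value_info entries: {len(model.graph.value_info)} -> {len(inferred.graph.value_info)}")

Node outputs downstream of ORT contrib ops keep no inferred type, which is exactly the warning case the comment describes.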
- # TODO: add support for checking ORT format model as it will have full type/shape info for all nodes - model_wrapper = ModelProtoWithShapeInfo(model_file) - return run_check_with_model(model_wrapper.model_with_shape_info, mobile_pkg_build_config, logger) - - -def main(): - parser = argparse.ArgumentParser( - description="Check if model can be run using the ONNX Runtime Mobile Pre-Built Package", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "--config_path", - help="Path to required operators and types configuration used to build the pre-built ORT mobile package.", - required=False, - type=pathlib.Path, - default=get_default_config_path(), - ) - - parser.add_argument("model_path", help="Path to ONNX model to check", type=pathlib.Path) - - args = parser.parse_args() - - logger = logging.getLogger("default") - logger.setLevel(logging.INFO) - run_check(args.model_path, args.config_path, logger) - - -if __name__ == "__main__": - main() diff --git a/tools/python/util/mobile_helpers/usability_checker.py b/tools/python/util/mobile_helpers/usability_checker.py index a8b5021f1387..e7948c43baa4 100644 --- a/tools/python/util/mobile_helpers/usability_checker.py +++ b/tools/python/util/mobile_helpers/usability_checker.py @@ -513,11 +513,11 @@ def check_nnapi_partitions(model, require_fixed_input_sizes: bool): return _check_ep_partitioning(model, config_path, require_fixed_input_sizes) -def check_coreml_partitions(model: onnx.ModelProto, require_fixed_input_sizes: bool, config_filename): +def check_coreml_partitions(model: onnx.ModelProto, require_fixed_input_sizes: bool, config_filename: str): # if we're running in the ORT python package the file should be local. otherwise assume we're running from the # ORT repo script_dir = pathlib.Path(__file__).parent - local_config = script_dir / "coreml_supported_ops.md" + local_config = script_dir / config_filename if local_config.exists(): config_path = local_config else: