diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index fc70fa14cff3..3965fe063b14 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -98,6 +98,13 @@ jobs: --exclude=onnxruntime/core/mlas/inc/* --exclude=onnxruntime/core/mlas/lib/* --exclude=onnxruntime/contrib_ops/cuda/bert/flash_attention/* + --exclude=build/Debug/* + --exclude=cmake/* + --exclude=csharp/test/* + --exclude=onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/* + --exclude=orttraining/orttraining/test/* + --exclude=onnxruntime/test/* + --exclude=winml/* filter: "-runtime/references" lint-js: diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 4eb865968cb5..29eb7045fc29 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d", + "commitHash": "f46495ea96f68fc3f6c394f099b2992743f6ff7f", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a9b0dfb30cc4..daacd221caa9 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -665,10 +665,10 @@ else() check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) if(onnxruntime_ENABLE_TRAINING_APIS) - check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) - if(HAS_DANGLING_REFERENCE) - list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) - endif() + check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) + if(HAS_DANGLING_REFERENCE) + list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) + endif() endif() check_function_exists(reallocarray HAS_REALLOCARRAY) if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") @@ -845,8 +845,8 @@ if (onnxruntime_USE_QNN) file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/libQnn*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/Qnn*.dll") if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc" OR ${QNN_ARCH_ABI} STREQUAL "arm64x-windows-msvc") file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" - "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" - "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so" + "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat") list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB}) endif() message(STATUS "QNN lib files: " ${QNN_LIB_FILES}) @@ -1057,6 +1057,9 @@ function(onnxruntime_set_compile_flags target_name) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:${FLAG}>") endforeach() + if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 13 AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 12) + target_compile_options(${target_name} PRIVATE "$<$:-Wno-maybe-uninitialized>") + endif() if (onnxruntime_USE_CUDA) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") @@ -1198,11 +1201,11 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 if (onnxruntime_USE_ACL_2002) add_definitions(-DACL_2002=1) else() - if (onnxruntime_USE_ACL_2308) - add_definitions(-DACL_2308=1) - else() + if (onnxruntime_USE_ACL_2308) + add_definitions(-DACL_2308=1) + else() add_definitions(-DACL_1905=1) - endif() + endif() endif() endif() endif() @@ -1529,7 +1532,7 @@ if (onnxruntime_ENABLE_TRAINING) list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard) endif() -if (UNIX AND onnxruntime_USE_NCCL) +if (UNIX OR onnxruntime_USE_NCCL) # MPI is INDEPENDENT of NCCL for now. You can build NCLL without MPI and launch multi-GPU with your own launcher. if (onnxruntime_USE_MPI) if (EXISTS "${onnxruntime_MPI_HOME}") diff --git a/cmake/deps.txt b/cmake/deps.txt index 96c183909bcb..72469603a088 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/f46495ea96f68fc3f6c394f099b2992743f6ff7f.zip;0e2b6d1dc7f0a808d1e23f7dd985f7bc18d52cbc coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 304aa77f5473..54bddcbdcf00 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -82,7 +82,10 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) set(mlas_platform_preprocess_srcs @@ -350,9 +353,12 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 402135adbdd8..05a50a55db40 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -192,32 +192,7 @@ if (onnxruntime_USE_TVM) endif() if (onnxruntime_USE_VSINPU) - add_definitions(-DUSE_VSINPU=1) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - file(GLOB_RECURSE onnxruntime_providers_vsinpu_srcs - "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vsinpu_srcs}) - add_library(onnxruntime_providers_vsinpu ${onnxruntime_providers_vsinpu_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vsinpu - onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11 - safeint_interface nsync::nsync_cpp) - add_dependencies(onnxruntime_providers_vsinpu ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_vsinpu PROPERTIES FOLDER "ONNXRuntime" LINKER_LANGUAGE CXX) - target_include_directories(onnxruntime_providers_vsinpu PRIVATE ${ONNXRUNTIME_ROOT} $ENV{TIM_VX_INSTALL}/include) - - find_library(TIMVX_LIBRARY NAMES tim-vx PATHS $ENV{TIM_VX_INSTALL}/lib NO_DEFAULT_PATH) - if(TIMVX_LIBRARY) - target_link_libraries(onnxruntime_providers_vsinpu PRIVATE ${TIMVX_LIBRARY}) - else() - message(FATAL_ERROR "Cannot find TIM-VX library!") - endif() - + include(onnxruntime_providers_vsinpu.cmake) endif() if (onnxruntime_USE_XNNPACK) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 3b48a40bf116..82c31ce6b6b4 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -33,19 +33,21 @@ ) - if (onnxruntime_CUDA_MINIMAL) - set(onnxruntime_providers_cuda_shared_srcs "") - else() + if (NOT onnxruntime_CUDA_MINIMAL) file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" ) + else() + set(onnxruntime_providers_cuda_cu_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu" + ) endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) # disable contrib ops conditionally - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL) if (NOT onnxruntime_ENABLE_ATEN) list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc" diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 38d4b52c97a8..71692ddb9391 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -49,7 +49,10 @@ find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) + find_package(rocm_smi REQUIRED) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB}) + include_directories(${ROCM_SMI_INCLUDE_DIR}) + link_directories(${ROCM_SMI_LIB_DIR}) file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 90203216600f..3d46c139feea 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - + if(onnxruntime_DISABLE_CONTRIB_OPS) + message( FATAL_ERROR "To compile TensorRT execution provider contrib ops have to be enabled to dump an engine using com.microsoft:EPContext node." ) + endif() add_definitions(-DUSE_TENSORRT=1) if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER) add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER) @@ -154,8 +156,11 @@ # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. - set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) - + if(onnxruntime_CUDA_MINIMAL) + set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + else() + set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) + endif() file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h" "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc" @@ -190,6 +195,9 @@ if (WIN32) target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456) endif() + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE USE_CUDA_MINIMAL=1) + endif() # Needed for the provider interface, as it includes training headers when training is enabled if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_vsinpu.cmake b/cmake/onnxruntime_providers_vsinpu.cmake new file mode 100644 index 000000000000..4b987fd1e424 --- /dev/null +++ b/cmake/onnxruntime_providers_vsinpu.cmake @@ -0,0 +1,37 @@ + add_definitions(-DUSE_VSINPU=1) + file(GLOB_RECURSE onnxruntime_providers_vsinpu_srcs + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vsinpu_srcs}) + add_library(onnxruntime_providers_vsinpu ${onnxruntime_providers_vsinpu_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vsinpu + onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11 + safeint_interface nsync::nsync_cpp) + add_dependencies(onnxruntime_providers_vsinpu ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_vsinpu PROPERTIES FOLDER "ONNXRuntime" LINKER_LANGUAGE CXX) + target_include_directories(onnxruntime_providers_vsinpu PRIVATE ${ONNXRUNTIME_ROOT} $ENV{TIM_VX_INSTALL}/include) + + find_library(TIMVX_LIBRARY NAMES tim-vx PATHS $ENV{TIM_VX_INSTALL}/lib NO_DEFAULT_PATH) + if(NOT TIMVX_LIBRARY) + message(FATAL_ERROR "TIM-VX library is not found!") + endif() + + if(CMAKE_CROSSCOMPILING) + message(STATUS "VSINPU ep will be cross compiled.") + if(EXISTS "$ENV{VIVANTE_SDK_DIR}/drivers") + set(DRIVER_DIR "$ENV{VIVANTE_SDK_DIR}/drivers") + elseif(EXISTS "$ENV{VIVANTE_SDK_DIR}/lib") + set(DRIVER_DIR "$ENV{VIVANTE_SDK_DIR}/lib") + else() + message(FATAL_ERROR "Neither drivers nor lib directory exists in this VIVANTE_SDK_DIR.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wl,-rpath-link ${DRIVER_DIR} ${TIMVX_LIBRARY}") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") + target_link_libraries(onnxruntime_providers_vsinpu PRIVATE ${TIMVX_LIBRARY}) + endif() diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index 584c49d61229..82983646527d 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -1,8 +1,43 @@ +diff --git a/absl/base/attributes.h b/absl/base/attributes.h +index 5ea5ee3e..f4949898 100644 +--- a/absl/base/attributes.h ++++ b/absl/base/attributes.h +@@ -559,7 +559,7 @@ + #undef ABSL_ATTRIBUTE_UNUSED + #define ABSL_ATTRIBUTE_UNUSED __attribute__((__unused__)) + #else +-#define ABSL_ATTRIBUTE_UNUSED ++#define ABSL_ATTRIBUTE_UNUSED [[maybe_unused]] + #endif + + // ABSL_ATTRIBUTE_INITIAL_EXEC +diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h +index d4fe8f5c..27418d13 100644 +--- a/absl/container/internal/raw_hash_set.h ++++ b/absl/container/internal/raw_hash_set.h +@@ -1924,7 +1924,7 @@ HashtablezInfoHandle SampleHashtablezInfo(size_t sizeof_slot, size_t sizeof_key, + // In SOO, we sample on the first insertion so if this is an empty SOO case + // (e.g. when reserve is called), then we still need to sample. + if (kSooEnabled && was_soo && c.size() == 0) { +- return Sample(sizeof_slot, sizeof_key, sizeof_value, SooCapacity()); ++ return Sample(sizeof_slot, sizeof_key, sizeof_value, (int16_t)SooCapacity()); + } + // For non-SOO cases, we sample whenever the capacity is increasing from zero + // to non-zero. +@@ -3525,7 +3525,7 @@ class raw_hash_set { + assert(is_soo()); + if (!ShouldSampleHashtablezInfo()) return HashtablezInfoHandle{}; + return Sample(sizeof(slot_type), sizeof(key_type), sizeof(value_type), +- SooCapacity()); ++ (int16_t)SooCapacity()); + } + + inline void destroy_slots() { diff --git a/absl/copts/GENERATED_AbseilCopts.cmake b/absl/copts/GENERATED_AbseilCopts.cmake -index a4ab1aa2..dfd13fd7 100644 +index da2282fe..4c7fc26f 100644 --- a/absl/copts/GENERATED_AbseilCopts.cmake +++ b/absl/copts/GENERATED_AbseilCopts.cmake -@@ -129,8 +129,6 @@ list(APPEND ABSL_MSVC_FLAGS +@@ -181,8 +181,6 @@ list(APPEND ABSL_MSVC_FLAGS "/wd4005" "/wd4068" "/wd4180" @@ -10,12 +45,12 @@ index a4ab1aa2..dfd13fd7 100644 - "/wd4267" "/wd4503" "/wd4800" - ) + "/DNOMINMAX" diff --git a/absl/copts/GENERATED_copts.bzl b/absl/copts/GENERATED_copts.bzl -index a6efc98e..8c4de8e7 100644 +index b9e0071e..dd8410ec 100644 --- a/absl/copts/GENERATED_copts.bzl +++ b/absl/copts/GENERATED_copts.bzl -@@ -130,8 +130,6 @@ ABSL_MSVC_FLAGS = [ +@@ -182,8 +182,6 @@ ABSL_MSVC_FLAGS = [ "/wd4005", "/wd4068", "/wd4180", @@ -23,12 +58,12 @@ index a6efc98e..8c4de8e7 100644 - "/wd4267", "/wd4503", "/wd4800", - ] + "/DNOMINMAX", diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index e6e11949..0aa7d868 100644 +index 2d85ac74..4875d668 100644 --- a/absl/copts/copts.py +++ b/absl/copts/copts.py -@@ -115,10 +115,6 @@ MSVC_WARNING_FLAGS = [ +@@ -118,10 +118,6 @@ MSVC_WARNING_FLAGS = [ "/wd4068", # unknown pragma # qualifier applied to function type has no meaning; ignored "/wd4180", diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 13d925e0fc2e..44d2222dbce1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1357,7 +1357,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca OrtAllocatorType allocatorType, int identifier, OrtMemType memType, - out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transfered to caller + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transferred to caller ); public static DOrtCreateMemoryInfo OrtCreateMemoryInfo; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 163a2b394c4a..5946e9fb1b16 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -22,7 +22,7 @@ public enum OnnxValueType ONNX_TYPE_MAP = 3, // It's a map ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor - ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKOWN) + ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN) } /// @@ -31,7 +31,7 @@ public enum OnnxValueType /// The class implements IDisposable and must /// be disposed of, otherwise native resources will leak /// and will eventually cause the application to slow down or crash. - /// + /// /// If the OrtValue instance is constructed over a managed memory, and it is not /// disposed properly, the pinned memory will continue to be pinned and interfere /// with GC operation. @@ -72,7 +72,7 @@ internal OrtValue(IntPtr handle, OnnxValueType onnxValueType) /// Constructor. The newly constructed OrtValue takes ownership of the native OrtValue instance /// and disposes of it when the OrtValue instance is disposed. The instance will take ownership and will /// dispose of compositeMembers instances. - /// + /// /// This constructor can only throw if OnnxType is not specified. /// /// native ortValue handle @@ -189,10 +189,10 @@ public OrtValue GetValue(int index, OrtAllocator allocator) /// /// Returns a ReadOnlySpan over tensor native buffer that /// provides a read-only view. - /// + /// /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. /// To get memory descriptor use GetTensorMemoryInfo(). - /// + /// /// OrtValue must contain a non-string tensor. /// The span is valid as long as the OrtValue instance is alive (not disposed). /// @@ -210,10 +210,10 @@ public ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged /// This enables you to safely and efficiently modify the underlying /// native buffer in a type-safe manner. This is useful for example in IOBinding scenarios /// where you want to modify results of the inference and feed it back as input. - /// + /// /// Note, that the memory may be device allocated. /// To get memory descriptor use GetTensorMemoryInfo(). - /// + /// /// OrtValue must contain a non-string tensor. /// The span is valid as long as the OrtValue instance is alive (not disposed). /// @@ -237,7 +237,7 @@ public Span GetTensorMutableRawData() /// /// Fetch string tensor element buffer pointer at the specified index, /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory instance. - /// + /// /// Obtain TensorTypeAndShape to get shape and element count. /// /// flat string tensor element index @@ -256,7 +256,7 @@ public ReadOnlyMemory GetStringElementAsMemory(int index) /// /// Fetch string tensor element buffer pointer at the specified index, /// copy/convert UTF-8 into a UTF-16 string and return it. - /// + /// /// Obtain TensorTypeAndShape to get shape and element count. /// /// flat string tensor element index @@ -279,7 +279,7 @@ public string GetStringElement(int index) /// /// Get a span over the native memory of the string tensor element. /// The span is valid as long as the OrtValue is valid. - /// + /// /// This is useful if you want to perform your own UTF-8 decoding or /// you do not care about decoding. /// Obtain TensorTypeAndShape to get shape and element count. @@ -483,7 +483,7 @@ private Span GetTensorBufferRawData(Type requestedType) /// This can be a piece of arbitrary memory that may be allocated by OrtAllocator (possibly on a device), /// a chunk of managed memory (must be pinned for the duration of OrtValue lifetime) or a memory that is allocated /// natively allocated using Marshal.AllocHGlobal(), stackalloc or other means (may be on a device). - /// + /// /// The resulting OrtValue does not own the underlying memory buffer and will not attempt to /// deallocate it. The caller must make sure that the memory remains valid for the duration of OrtValue lifetime. /// @@ -769,12 +769,12 @@ out IntPtr valueHandle /// Converts the string argument represented by ReadOnlySpan to UTF-8, /// allocates space in the native tensor and copies it into the native tensor memory. /// Typically, this is used to populate a new empty string tensor element. - /// + /// /// The number of elements is according to the shape supplied to CreateTensorWithEmptyStrings(). /// However, this API can also be used to overwrite any existing element within the string tensor. - /// + /// /// In general, to obtain the number of elements for any tensor, use GetTensorTypeAndShape() which - /// would return a disposable instance of TensorTypeAndShapeInfo. + /// would return a disposable instance of TensorTypeAndShapeInfo. /// Then call GetElementCount() or GetShape(). /// /// ReadOnlySpan over chars @@ -795,12 +795,12 @@ public void StringTensorSetElementAt(ReadOnlySpan str, int index) /// Converts the string argument represented by ReadOnlyMemory to UTF-8, /// allocates space in the native tensor and copies it into the native tensor memory. /// Typically, this is used to populate a new empty string tensor element. - /// + /// /// The number of elements is according to the shape supplied to CreateTensorWithEmptyStrings(). /// However, this API can also be used to overwrite any existing element within the string tensor. - /// + /// /// In general, to obtain the number of elements for any tensor, use GetTensorTypeAndShape() which - /// would return a disposable instance of TensorTypeAndShapeInfo. + /// would return a disposable instance of TensorTypeAndShapeInfo. /// Then call GetElementCount() or GetShape(). /// /// @@ -815,7 +815,7 @@ public void StringTensorSetElementAt(ReadOnlyMemory rom, int index) /// /// This API resizes String Tensor element to the requested amount of bytes (UTF-8) /// and copies the bytes from the supplied ReadOnlySpan into the native tensor memory (resized buffer). - /// + /// /// The API is useful for quick loading of utf8 data into the native tensor memory. /// /// read only span of bytes @@ -841,7 +841,7 @@ public void StringTensorSetElementAt(ReadOnlySpan utf8Bytes, int index) /// Creates an OrtValue that contains a string tensor. /// String tensors are always allocated on CPU. /// String data will be converted to UTF-8 and copied to native memory. - /// + /// /// Note, this is different from creating an OrtValue from other primitive data types /// where memory is pinned (if necessary) and the OrtValue points to that chunk of memory. /// @@ -885,10 +885,10 @@ public static OrtValue CreateFromStringTensor(Tensor tensor) /// Creates a sequence of OrtValues from a collection of OrtValues. /// All OrtValues in the collection must be of the same Onnx type. /// I.e. (Tensor, SparseTensor, Map, Sequence, etc.) - /// + /// /// The ortValues that are passed as argument are taken possession of by the newly /// created OrtValue. The caller should not dispose them, unless this call fails. - /// + /// /// The ortValues would be empty on successful return. /// /// a collection of OrtValues. On success the ortValues contained in the list @@ -978,24 +978,24 @@ public void ProcessSequence(SequenceElementVisitor visitor, OrtAllocator allocat /// Creates a map OrtValue with keys and values. /// On a high level the Onnxruntime representation of the map always consists of two /// OrtValues, keys and values. - /// + /// /// According to ONNX standard map keys can be unmanaged types only (or strings). /// Those keys are contained in a single tensor within OrtValue keys. - /// + /// /// Map values, on the other hand, can be composite types. The values parameter /// can either contain a single tensor with unmanaged map values with the same number of /// elements as the keys, or it can be a sequence of OrtValues, /// each of those can be a composite type (tensor, sequence, map). If it is a sequence, /// then the number of elements must match the number of elements in keys. - /// + /// /// Keys and values must be in the same order. - /// + /// /// ORT supports only a subset of types for keys and values, however, this API does not /// restrict it. - /// + /// /// The ortValues that are passed as argument are taken possession of by the newly /// created OrtValue. The caller should not dispose them, unless this call fails. - /// + /// /// Keys and values arguments will be set to null on success. /// /// Contains keys @@ -1031,10 +1031,10 @@ public static OrtValue CreateMap(ref OrtValue keys, ref OrtValue values) /// This API helps to quickly creates a map OrtValue with unmanaged (primitive) keys and values specified as arrays. /// This helps the user not to create OrtValues for keys and values separately and deal only with the final result. /// The map would consist of two tensors, one for keys and one for values. - /// + /// /// The OrtValues would be created on top of the managed memory arrays and use it directly. /// The number of elements in keys and values must be the same and they must be in order. - /// + /// /// The types must be unmanaged. /// /// keys type @@ -1078,10 +1078,10 @@ public static OrtValue CreateMap(K[] keys, V[] values) where K : unmanaged /// This helps the user not to create OrtValues for keys and values separately. /// The number of elements in keys and values must be the same and they must be in order. /// The map would consist of two tensors, one for keys and one for values. - /// + /// /// string keys would be converted to UTF-8 encoding and copied to an allocated native memory. /// The OrtValue for values would be created on top of the managed memory using it directly. - /// + /// /// The values type must be unmanaged. /// /// @@ -1128,13 +1128,13 @@ public static OrtValue CreateMapWithStringKeys(IReadOnlyCollection ke /// /// Creates a map OrtValue with non-string keys and string values. - /// + /// /// This helps the user not to create OrtValues for keys and values separately. /// The number of elements in keys and values must be the same and they must be in order. - /// + /// /// The OrtValue for keys would be created on top of the managed memory using it directly. /// string values would be converted to UTF-8 encoding and copied to an allocated native memory. - /// + /// /// /// unmanaged type of keys /// @@ -1182,17 +1182,17 @@ public static OrtValue CreateMapWithStringValues(K[] keys, IReadOnlyCollectio /// Typically, when one uses GetValue() API, it creates a copy of OrtValue /// that points to the same buffer as keys or values. This API helps to deal with those /// temporary instances and avoid leaks. - /// + /// /// According to ONNX standard map keys can be unmanaged types only (or strings). /// Those keys are contained in a single tensor within OrtValue keys. So you can query those /// directly from keys argument. - /// + /// /// Map values, on the other hand, can be composite types. The values parameter /// can either contain a single tensor with unmanaged map values with the same number of /// elements as the keys, or it can be a sequence of OrtValues, /// each of those can be a composite type (tensor, sequence, map). If it is a sequence, /// then the number of elements must match the number of elements in keys. - /// + /// /// Depending on the structure of the values, one will either directly query a single tensor /// from values, or will have to iterate over the sequence of OrtValues and visit each of those /// resulting in a recursive visitation. @@ -1204,7 +1204,7 @@ public static OrtValue CreateMapWithStringValues(K[] keys, IReadOnlyCollectio /// /// This API helps the user to process a map OrtValue without /// having to deal with the lifespan of intermediate OrtValues. - /// + /// /// each API value is fed to the vistor functor. /// /// visitor function diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index d6a6b9627f41..0892e17fc97b 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -116,7 +116,7 @@ public void TestSessionOptions() var directml_dll_path = AppDomain.CurrentDomain.BaseDirectory; SetDllDirectory(directml_dll_path); - + try { opt.AppendExecutionProvider_DML(0); @@ -124,7 +124,7 @@ public void TestSessionOptions() catch (OnnxRuntimeException ortException) { // if we run on a CI machine with the incorrect hardware we might get an error due to that. - // allow that as the call made it through to the DML EP so the C# layer is working correctly. + // allow that as the call made it through to the DML EP so the C# layer is working correctly. // any other exception type or error message is considered a failure. Assert.Contains("The specified device interface or feature level is not supported on this system.", ortException.Message); @@ -1895,7 +1895,7 @@ private void TestSharedAllocatorUsingCreateAndRegisterAllocator() sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); // Create two sessions to share the allocator - // Create a thrid session that DOES NOT use the allocator in the environment + // Create a third session that DOES NOT use the allocator in the environment using (var session1 = new InferenceSession(model, sessionOptions)) using (var session2 = new InferenceSession(model, sessionOptions)) using (var session3 = new InferenceSession(model)) // Use the default SessionOptions instance @@ -2127,7 +2127,7 @@ private void TestLoadAzureEP() } catch (Exception) { Assert.True(false); - } + } } } } diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs index 9370a03f7fbe..a005efa749a1 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs @@ -32,7 +32,7 @@ class CommandOptions [Option('i', "input_file", Required = false, HelpText = "Input file.")] public string InputFile { get; set; } - [Option('p', Required = false, HelpText = "Run with parallel exection. Default is false")] + [Option('p', Required = false, HelpText = "Run with parallel execution. Default is false")] public bool ParallelExecution { get; set; } = false; [Option('o', "optimization_level", Required = false, HelpText = "Optimization Level. Default is 99, all optimization.")] diff --git a/include/onnxruntime/core/common/gsl.h b/include/onnxruntime/core/common/gsl.h deleted file mode 100644 index 371c5b7543b5..000000000000 --- a/include/onnxruntime/core/common/gsl.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "gsl/gsl" diff --git a/include/onnxruntime/core/common/logging/capture.h b/include/onnxruntime/core/common/logging/capture.h index 2af050918706..13d3a3ad17af 100644 --- a/include/onnxruntime/core/common/logging/capture.h +++ b/include/onnxruntime/core/common/logging/capture.h @@ -4,7 +4,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/code_location.h" #include "core/common/logging/severity.h" diff --git a/include/onnxruntime/core/common/span_utils.h b/include/onnxruntime/core/common/span_utils.h index b2d1aefee9c0..9f7454625fcd 100644 --- a/include/onnxruntime/core/common/span_utils.h +++ b/include/onnxruntime/core/common/span_utils.h @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/include/onnxruntime/core/eager/ort_kernel_invoker.h b/include/onnxruntime/core/eager/ort_kernel_invoker.h index 1d1046742db4..fcf92de2ee39 100644 --- a/include/onnxruntime/core/eager/ort_kernel_invoker.h +++ b/include/onnxruntime/core/eager/ort_kernel_invoker.h @@ -24,7 +24,10 @@ class ORTInvoker { public: ORTInvoker(std::shared_ptr execution_provider, const logging::Logger& logger, - const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) : execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) { + const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) + : execution_provider_(std::move(execution_provider)), + logger_(logger), + custom_op_registries_(custom_op_registries) { if (!execution_provider_) { ORT_THROW("Execution provider is nullptr"); } diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index b197d8809043..87feefa10ca4 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -9,7 +9,7 @@ #include #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/exceptions.h" #include "core/framework/endian.h" diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 16ad943a5f47..49c3d1bdd088 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -231,7 +231,7 @@ class IExecutionProvider { const std::reference_wrapper filtered_graph; }; - // Fusion approach that is suppported + // Fusion approach that is supported // !!! The "Function" FusionStyle is deprecated. // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style. enum class FusionStyle { diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h index aff365dc9738..0282b84bd0f8 100644 --- a/include/onnxruntime/core/framework/int4.h +++ b/include/onnxruntime/core/framework/int4.h @@ -6,7 +6,7 @@ #include #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 94c6d81ee932..07625c38d847 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -25,7 +25,7 @@ #include "core/graph/constants.h" #include "core/graph/graph_viewer.h" #include "core/graph/onnx_protobuf.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { class OpKernelContext; } @@ -94,7 +94,7 @@ class OpKernel { } // Override this function to return a list of attributes the session can safely remove - // after it is intialized and saved. This option is useful to reduce memory usage + // after it is initialized and saved. This option is useful to reduce memory usage // when the kernel does not reuse the operator attributes but copies them. // All attributes returned by this method will be removed by method // PruneRemovableAttributes of they exists. diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h index a0bbfe50a700..1510cdc9d145 100644 --- a/include/onnxruntime/core/framework/op_kernel_info.h +++ b/include/onnxruntime/core/framework/op_kernel_info.h @@ -8,7 +8,7 @@ #include "core/framework/ort_value.h" #include "core/framework/op_node_proto_helper.h" #include "core/graph/graph_viewer.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/include/onnxruntime/core/framework/op_node_proto_helper.h b/include/onnxruntime/core/framework/op_node_proto_helper.h index e7ac01947af4..5cbaaa0212c5 100644 --- a/include/onnxruntime/core/framework/op_node_proto_helper.h +++ b/include/onnxruntime/core/framework/op_node_proto_helper.h @@ -7,7 +7,7 @@ #include "core/common/status.h" #include "core/framework/tensor_shape.h" #include "core/graph/graph_viewer.h" -#include "core/common/gsl.h" +#include #endif class IMLOpKernel; diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index 26d78133b52f..9c987f10ccad 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -54,14 +54,14 @@ class Stream { // update its lookup table with the table snapshot in notification. // The memory reusing strategy is: // A kernel in current stream is safe to reuse another stream's memory chunk - // as long as the reused chunk's timestamp is less than the last synchonized + // as long as the reused chunk's timestamp is less than the last synchronized // timestamp recorded in the lookup table. // Get the current timestamp uint64_t GetCurrentTimestamp() const { return timestamp_; } // return the timestamp when the last synchronization happened between target stream and current stream. - // return 0 if no synchonization happened. + // return 0 if no synchronization happened. // if target_stream is nullptr, it means it is a sequence running on device doesn't support Stream (i.e. CPU) // we can safely return 0 in that case to save a lookup. uint64_t GetLastSyncTimestampWithTargetStream(Stream* target_stream) const { diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 96725aa10306..dd2603d214f6 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -9,7 +9,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/allocator.h" #include "core/framework/tensor_shape.h" diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index 82a1c1de8352..d4ee4a0e5e64 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -9,7 +9,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/common/span_utils.h" #include "onnxruntime_config.h" diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 7dabe42ba0a2..9289e14c17dd 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -14,7 +14,7 @@ #include "core/common/flatbuffers.h" -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/path_string.h" diff --git a/include/onnxruntime/core/optimizer/graph_transformer_config.h b/include/onnxruntime/core/optimizer/graph_transformer_config.h index c112d9b0480a..6af48331270c 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_config.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_config.h @@ -13,7 +13,7 @@ struct GraphTransformerConfiguration { /* * Cast propagation strategy. * One strategy is to insert casts around all the nodes with the allowed opcodes - * and reduce, by removing redundent-casts and back-to-back-casts etc., and + * and reduce, by removing redundant-casts and back-to-back-casts etc., and * the other is to propagate casts using flood-fill approach, expanding float16 regions in the graph * traversing the graph up/down. */ @@ -70,4 +70,4 @@ constexpr GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy constexpr bool operator==(GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy, GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy); constexpr bool operator!=(GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy, GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy); -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 7104e70c3a8a..9ada01673d4d 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -12,14 +12,16 @@ #define ORT_CUDA_CTX -#include "cuda_resource.h" -#include "core/providers/custom_op_context.h" #include #include #ifndef USE_CUDA_MINIMAL #include #include #endif + +#include "core/providers/cuda/cuda_resource.h" +#include "core/providers/custom_op_context.h" + namespace Ort { namespace Custom { diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 3915292648ae..1be8c22b40da 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -349,7 +349,7 @@ private SessionOptions parseSessionOptions(ReadableMap options) throws OrtExcept if (options.hasKey("interOpNumThreads")) { int interOpNumThreads = options.getInt("interOpNumThreads"); if (interOpNumThreads > 0 && interOpNumThreads < Integer.MAX_VALUE) { - sessionOptions.setIntraOpNumThreads(interOpNumThreads); + sessionOptions.setInterOpNumThreads(interOpNumThreads); } } diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 725d11b9d54c..8d077846fa6a 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -21,7 +21,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Ceil | ai.onnx(7-12, 13+) | ceil | ✓ | ✓ | | | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | | -| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU requires the 'W' (weight) input to be a constant | +| Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✗ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | diff --git a/js/web/lib/onnxjs/execution-plan.ts b/js/web/lib/onnxjs/execution-plan.ts index 5599087ab46f..e155ff123f79 100644 --- a/js/web/lib/onnxjs/execution-plan.ts +++ b/js/web/lib/onnxjs/execution-plan.ts @@ -92,7 +92,7 @@ export class ExecutionPlan { const inputTensors = inputList as Tensor[]; Logger.verbose( 'ExecPlan', - `Runing op:${thisOp.node.name} (${ + `Running op:${thisOp.node.name} (${ inputTensors.map((t, i) => `'${thisOp.node.inputs[i]}': ${t.type}[${t.dims.join(',')}]`).join(', ')})`); const outputList = await this.profiler.event( diff --git a/js/web/lib/onnxjs/graph.ts b/js/web/lib/onnxjs/graph.ts index f16da4281595..d444be2bf7ce 100644 --- a/js/web/lib/onnxjs/graph.ts +++ b/js/web/lib/onnxjs/graph.ts @@ -674,7 +674,7 @@ class GraphImpl implements Graph, Graph.Transformer { } /** - * Delete the specifed node. Assume the node has one incoming input and the first output connected to other nodes. + * Delete the specified node. Assume the node has one incoming input and the first output connected to other nodes. * An input validation must be done before calling this function. * @param nodeIndex The index of node to be deleted */ diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index d697a8b3138c..22c4e4c755f5 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -474,7 +474,7 @@ export class ProtoUtil { export class LongUtil { // This function is called to get a number from long type of data for attribute, dim, and ir version, // which values are signed integers. - // To make it more generic, add an optional paramter to convert to a unsigned number. + // To make it more generic, add an optional parameter to convert to a unsigned number. static longToNumber(n: Long|flatbuffers.Long|number, unsigned?: boolean) { if (Long.isLong(n)) { return n.toNumber(); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts new file mode 100644 index 000000000000..f8a1e1966fd4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +interface NavigatorML { + readonly ml: ML; +} +interface Navigator extends NavigatorML {} +interface WorkerNavigator extends NavigatorML {} +type MLDeviceType = 'cpu'|'gpu'|'npu'; +type MLPowerPreference = 'default'|'high-performance'|'low-power'; +interface MLContextOptions { + deviceType?: MLDeviceType; + powerPreference?: MLPowerPreference; + numThreads?: number; +} +interface ML { + createContext(options?: MLContextOptions): Promise; + createContext(gpuDevice: GPUDevice): Promise; +} +type MLNamedArrayBufferViews = Record; +interface MLComputeResult { + inputs?: MLNamedArrayBufferViews; + outputs?: MLNamedArrayBufferViews; +} +interface MLContext { + compute(graph: MLGraph, inputs: MLNamedArrayBufferViews, outputs: MLNamedArrayBufferViews): Promise; +} +interface MLGraph {} +type MLInputOperandLayout = 'nchw'|'nhwc'; +type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; +interface MLOperandDescriptor { + dataType: MLOperandDataType; + dimensions?: number[]; +} +interface MLOperand { + dataType(): MLOperandDataType; + shape(): number[]; +} +interface MLActivation {} +type MLNamedOperands = Record; +interface MLGraphBuilder { + // eslint-disable-next-line @typescript-eslint/no-misused-new + new(context: MLContext): MLGraphBuilder; + input(name: string, descriptor: MLOperandDescriptor): MLOperand; + constant(descriptor: MLOperandDescriptor, bufferView: ArrayBufferView): MLOperand; + constant(type: MLOperandDataType, value: number): MLOperand; + build(outputs: MLNamedOperands): Promise; +} +interface MLArgMinMaxOptions { + axes?: number[]; + keepDimensions?: boolean; + selectLastIndex?: boolean; +} +interface MLGraphBuilder { + argMin(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand; + argMax(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand; +} +interface MLBatchNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + axis?: number; + epsilon?: number; +} +interface MLGraphBuilder { + batchNormalization(input: MLOperand, mean: MLOperand, variance: MLOperand, options?: MLBatchNormalizationOptions): + MLOperand; +} +interface MLGraphBuilder { + cast(input: MLOperand, type: MLOperandDataType): MLOperand; +} +interface MLClampOptions { + minValue?: number; + maxValue?: number; +} +interface MLGraphBuilder { + clamp(input: MLOperand, options?: MLClampOptions): MLOperand; + clamp(options?: MLClampOptions): MLActivation; +} +interface MLGraphBuilder { + concat(inputs: MLOperand[], axis: number): MLOperand; +} +type MLConv2dFilterOperandLayout = 'oihw'|'hwio'|'ohwi'|'ihwo'; +interface MLConv2dOptions { + padding?: number[]; + strides?: number[]; + dilations?: number[]; + groups?: number; + inputLayout?: MLInputOperandLayout; + filterLayout?: MLConv2dFilterOperandLayout; + bias?: MLOperand; +} +interface MLGraphBuilder { + conv2d(input: MLOperand, filter: MLOperand, options?: MLConv2dOptions): MLOperand; +} +type MLConvTranspose2dFilterOperandLayout = 'iohw'|'hwoi'|'ohwi'; +interface MLConvTranspose2dOptions { + padding?: number[]; + strides?: number[]; + dilations?: number[]; + outputPadding?: number[]; + outputSizes?: number[]; + groups?: number; + inputLayout?: MLInputOperandLayout; + filterLayout?: MLConvTranspose2dFilterOperandLayout; + bias?: MLOperand; +} +interface MLGraphBuilder { + convTranspose2d(input: MLOperand, filter: MLOperand, options?: MLConvTranspose2dOptions): MLOperand; +} +interface MLGraphBuilder { + add(a: MLOperand, b: MLOperand): MLOperand; + sub(a: MLOperand, b: MLOperand): MLOperand; + mul(a: MLOperand, b: MLOperand): MLOperand; + div(a: MLOperand, b: MLOperand): MLOperand; + max(a: MLOperand, b: MLOperand): MLOperand; + min(a: MLOperand, b: MLOperand): MLOperand; + pow(a: MLOperand, b: MLOperand): MLOperand; +} +interface MLGraphBuilder { + equal(a: MLOperand, b: MLOperand): MLOperand; + greater(a: MLOperand, b: MLOperand): MLOperand; + greaterOrEqual(a: MLOperand, b: MLOperand): MLOperand; + lesser(a: MLOperand, b: MLOperand): MLOperand; + lesserOrEqual(a: MLOperand, b: MLOperand): MLOperand; + logicalNot(a: MLOperand): MLOperand; +} +interface MLGraphBuilder { + abs(input: MLOperand): MLOperand; + ceil(input: MLOperand): MLOperand; + cos(input: MLOperand): MLOperand; + erf(input: MLOperand): MLOperand; + exp(input: MLOperand): MLOperand; + floor(input: MLOperand): MLOperand; + identity(input: MLOperand): MLOperand; + log(input: MLOperand): MLOperand; + neg(input: MLOperand): MLOperand; + reciprocal(input: MLOperand): MLOperand; + sin(input: MLOperand): MLOperand; + sqrt(input: MLOperand): MLOperand; + tan(input: MLOperand): MLOperand; +} +interface MLEluOptions { + alpha?: number; +} +interface MLGraphBuilder { + elu(input: MLOperand, options?: MLEluOptions): MLOperand; + elu(options?: MLEluOptions): MLActivation; +} +interface MLGraphBuilder { + expand(input: MLOperand, newShape: number[]): MLOperand; +} +interface MLGatherOptions { + axis?: number; +} +interface MLGraphBuilder { + gather(input: MLOperand, indices: MLOperand, options?: MLGatherOptions): MLOperand; +} +interface MLGraphBuilder { + gelu(input: MLOperand): MLOperand; + gelu(): MLActivation; +} +interface MLGemmOptions { + c?: MLOperand; + alpha?: number; + beta?: number; + aTranspose?: boolean; + bTranspose?: boolean; +} +interface MLGraphBuilder { + gemm(a: MLOperand, b: MLOperand, options?: MLGemmOptions): MLOperand; +} +type MLGruWeightLayout = 'zrn'|'rzn'; +type MLRecurrentNetworkDirection = 'forward'|'backward'|'both'; +interface MLGruOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + initialHiddenState?: MLOperand; + resetAfter?: boolean; + returnSequence?: boolean; + direction?: MLRecurrentNetworkDirection; + layout?: MLGruWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + gru(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, + options?: MLGruOptions): MLOperand[]; +} +interface MLGruCellOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + resetAfter?: boolean; + layout?: MLGruWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + gruCell( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, hiddenSize: number, + options?: MLGruCellOptions): MLOperand; +} +interface MLHardSigmoidOptions { + alpha?: number; + beta?: number; +} +interface MLGraphBuilder { + hardSigmoid(input: MLOperand, options?: MLHardSigmoidOptions): MLOperand; + hardSigmoid(options?: MLHardSigmoidOptions): MLActivation; +} +interface MLGraphBuilder { + hardSwish(input: MLOperand): MLOperand; + hardSwish(): MLActivation; +} +interface MLInstanceNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + epsilon?: number; + layout?: MLInputOperandLayout; +} +interface MLGraphBuilder { + instanceNormalization(input: MLOperand, options?: MLInstanceNormalizationOptions): MLOperand; +} +interface MLLayerNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + axes?: number[]; + epsilon?: number; +} +interface MLGraphBuilder { + layerNormalization(input: MLOperand, options?: MLLayerNormalizationOptions): MLOperand; +} +interface MLLeakyReluOptions { + alpha?: number; +} +interface MLGraphBuilder { + leakyRelu(input: MLOperand, options?: MLLeakyReluOptions): MLOperand; + leakyRelu(options?: MLLeakyReluOptions): MLActivation; +} +interface MLLinearOptions { + alpha?: number; + beta?: number; +} +interface MLGraphBuilder { + linear(input: MLOperand, options?: MLLinearOptions): MLOperand; + linear(options?: MLLinearOptions): MLActivation; +} +type MLLstmWeightLayout = 'iofg'|'ifgo'; +interface MLLstmOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + peepholeWeight?: MLOperand; + initialHiddenState?: MLOperand; + initialCellState?: MLOperand; + returnSequence?: boolean; + direction?: MLRecurrentNetworkDirection; + layout?: MLLstmWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + lstm( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, + options?: MLLstmOptions): MLOperand[]; +} +interface MLLstmCellOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + peepholeWeight?: MLOperand; + layout?: MLLstmWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + lstmCell( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, cellState: MLOperand, + hiddenSize: number, options?: MLLstmCellOptions): MLOperand[]; +} +interface MLGraphBuilder { + matmul(a: MLOperand, b: MLOperand): MLOperand; +} +type MLPaddingMode = 'constant'|'edge'|'reflection'|'symmetric'; +interface MLPadOptions { + mode?: MLPaddingMode; + value?: number; +} +interface MLGraphBuilder { + pad(input: MLOperand, beginningPadding: number[], endingPadding: number[], options?: MLPadOptions): MLOperand; +} +type MLRoundingType = 'floor'|'ceil'; +interface MLPool2dOptions { + windowDimensions?: number[]; + padding?: number[]; + strides?: number[]; + dilations?: number[]; + layout?: MLInputOperandLayout; + roundingType?: MLRoundingType; + outputSizes?: number[]; +} +interface MLGraphBuilder { + averagePool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; + l2Pool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; + maxPool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; +} +interface MLGraphBuilder { + prelu(input: MLOperand, slope: MLOperand): MLOperand; +} +interface MLReduceOptions { + axes?: number[]; + keepDimensions?: boolean; +} +interface MLGraphBuilder { + reduceL1(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceL2(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceLogSum(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceLogSumExp(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMax(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMean(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMin(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceProduct(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceSum(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceSumSquare(input: MLOperand, options?: MLReduceOptions): MLOperand; +} +interface MLGraphBuilder { + relu(input: MLOperand): MLOperand; + relu(): MLActivation; +} +type MLInterpolationMode = 'nearest-neighbor'|'linear'; +interface MLResample2dOptions { + mode?: MLInterpolationMode; + scales?: number[]; + sizes?: number[]; + axes?: number[]; +} +interface MLGraphBuilder { + resample2d(input: MLOperand, options?: MLResample2dOptions): MLOperand; +} +interface MLGraphBuilder { + reshape(input: MLOperand, newShape: number[]): MLOperand; +} +interface MLGraphBuilder { + sigmoid(input: MLOperand): MLOperand; + sigmoid(): MLActivation; +} +interface MLGraphBuilder { + slice(input: MLOperand, starts: number[], sizes: number[]): MLOperand; +} +interface MLGraphBuilder { + softmax(input: MLOperand, axis: number): MLOperand; + softmax(axis: number): MLActivation; +} +interface MLGraphBuilder { + softplus(input: MLOperand): MLOperand; + softplus(): MLActivation; +} +interface MLGraphBuilder { + softsign(input: MLOperand): MLOperand; + softsign(): MLActivation; +} +interface MLSplitOptions { + axis?: number; +} +interface MLGraphBuilder { + split(input: MLOperand, splits: number|number[], options?: MLSplitOptions): MLOperand[]; +} +interface MLGraphBuilder { + tanh(input: MLOperand): MLOperand; + tanh(): MLActivation; +} +interface MLTransposeOptions { + permutation?: number[]; +} +interface MLGraphBuilder { + transpose(input: MLOperand, options?: MLTransposeOptions): MLOperand; +} +interface MLTriangularOptions { + upper?: boolean; + diagonal?: number; +} +interface MLGraphBuilder { + triangular(input: MLOperand, options?: MLTriangularOptions): MLOperand; +} +interface MLGraphBuilder { + where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand; +} + +// Experimental MLBuffer interface + +type MLSize64Out = number; +interface MLBuffer { + readonly size: MLSize64Out; + destroy(): void; +} +type MLSize64 = number; +interface MLBufferDescriptor { + size: MLSize64; +} +type MLNamedBuffers = Record; +interface MLContext { + createBuffer(descriptor: MLBufferDescriptor): MLBuffer; + writeBuffer( + dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64, + srcElementSize?: MLSize64): void; + readBuffer(srcBuffer: MLBuffer): Promise; + dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; +} diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 4d2b80e31a47..f289fc20bba4 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -66,8 +66,6 @@ const setExecutionProviders = const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; - const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; - const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; if (deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); const valueDataOffset = allocWasmString(deviceType, allocs); @@ -76,26 +74,6 @@ const setExecutionProviders = checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); } } - if (numThreads !== undefined) { - // Just ignore invalid webnnOptions.numThreads. - const validatedNumThreads = - (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 : - numThreads; - const keyDataOffset = allocWasmString('numThreads', allocs); - const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`); - } - } - if (powerPreference) { - const keyDataOffset = allocWasmString('powerPreference', allocs); - const valueDataOffset = allocWasmString(powerPreference, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`); - } - } } break; case 'webgpu': diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index a483ff09f000..9fc8786192c5 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; @@ -69,7 +74,7 @@ const initOrt = (numThreads: number, loggingLevel: number): void => { }; /** - * intialize runtime environment. + * initialize runtime environment. * @param env passed in the environment config object. */ export const initRuntime = async(env: Env): Promise => { @@ -253,11 +258,43 @@ export const createSession = async( await Promise.all(loadingPromises); } + for (const provider of options?.executionProviders ?? []) { + const providerName = typeof provider === 'string' ? provider : provider.name; + if (providerName === 'webnn') { + if (wasm.currentContext) { + throw new Error('WebNN execution provider is already set.'); + } + if (typeof provider !== 'string') { + const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption; + const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; + const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice; + const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; + if (context) { + wasm.currentContext = context as MLContext; + } else if (gpuDevice) { + wasm.currentContext = await navigator.ml.createContext(gpuDevice); + } else { + wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference}); + } + } else { + wasm.currentContext = await navigator.ml.createContext(); + } + break; + } + } + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } + // clear current MLContext after session creation + if (wasm.currentContext) { + wasm.currentContext = undefined; + } + const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); const enableGraphCapture = !!options?.enableGraphCapture; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 9ced89651e84..70728c82e775 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import type {Tensor} from 'onnxruntime-common'; /* eslint-disable @typescript-eslint/naming-convention */ @@ -19,7 +24,7 @@ export declare namespace JSEP { type CaptureEndFunction = () => void; type ReplayFunction = () => void; - export interface Module extends WebGpuModule { + export interface Module extends WebGpuModule, WebNnModule { /** * Mount the external data file to an internal map, which will be used during session initialization. * @@ -106,6 +111,13 @@ export declare namespace JSEP { */ jsepOnReleaseSession: (sessionId: number) => void; } + + export interface WebNnModule { + /** + * Active MLContext used to create WebNN EP. + */ + currentContext: MLContext; + } } export interface OrtInferenceAPIs { diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 6718dcb639a4..fbde81524cce 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -613,7 +613,7 @@ async function main() { // == The Problem == // every time when a test is completed, it will be added to the recovery page list. // if we run the test 100 times, there will be 100 previous tabs when we launch Edge again. - // this run out of resources quickly and fails the futher test. + // this run out of resources quickly and fails the further test. // and it cannot recover by itself because every time it is terminated forcely or crashes. // and the auto recovery feature has no way to disable by configuration/commandline/registry // diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h b/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h index 37ef36ac911b..c0ae5f8ceba9 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_mechanism.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc index 63ec5be8c290..72c5a813e3d7 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc +++ b/onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc @@ -57,7 +57,7 @@ void AttentionWrapper::ProcessOutput(const gsl::span& rnn_cell_outpu if (has_attn_layer_) { // concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) = // p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_ - // The first part is calulated above. Here just add the later. + // The first part is calculated above. Here just add the later. math::GemmEx(CblasNoTrans, CblasNoTrans, batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0}, attn_context_.data(), attn_context_depth_, diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h index aad077489186..ce91760516cb 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h @@ -11,7 +11,7 @@ #include "core/common/logging/logging.h" #include "core/framework/allocator.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/cdist.cc b/onnxruntime/contrib_ops/cpu/cdist.cc index d0ed81a9a6dc..736dbcfede2f 100644 --- a/onnxruntime/contrib_ops/cpu/cdist.cc +++ b/onnxruntime/contrib_ops/cpu/cdist.cc @@ -67,7 +67,7 @@ static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, co threadpool); #else // the performance of this isn't great as the eigen matmul is single threaded by default - // if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimising this + // if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimizing this // we can look into it in the future. ORT_UNUSED_PARAMETER(threadpool); diff --git a/onnxruntime/contrib_ops/cpu/crop.h b/onnxruntime/contrib_ops/cpu/crop.h index 0fd0a5c49b3b..3b72ef429c1f 100644 --- a/onnxruntime/contrib_ops/cpu/crop.h +++ b/onnxruntime/contrib_ops/cpu/crop.h @@ -6,7 +6,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/inverse.cc b/onnxruntime/contrib_ops/cpu/inverse.cc index 355b036e36d0..54bd99d20957 100644 --- a/onnxruntime/contrib_ops/cpu/inverse.cc +++ b/onnxruntime/contrib_ops/cpu/inverse.cc @@ -53,7 +53,7 @@ struct Inverse::ComputeImpl { void operator()(const Tensor* input, Tensor* output, int64_t batch_num, int64_t rows, int64_t cols) const { auto batch_offset = batch_num * rows * cols; - // Direct cast to half as it just as MLFloat16 containes only uint16_t + // Direct cast to half as it just as MLFloat16 contains only uint16_t const auto* input_data = reinterpret_cast(input->Data() + batch_offset); auto* output_data = reinterpret_cast(output->MutableData() + batch_offset); diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc index ed6071b40feb..de1798e54874 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc @@ -15,7 +15,7 @@ #include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 688b7d6341ae..12fae5ccf098 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -26,7 +26,7 @@ #include "core/framework/TensorSeq.h" #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/beam_search.h" #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/sequences.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 8c1ceec62fec..3bdb274d7d5a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -8,7 +8,7 @@ #include "core/providers/cpu/math/softmax_shared.h" #include "core/providers/cpu/generator/random.h" #include "core/common/safeint.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_scorer.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 8f778c57bb41..7f99a808f442 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -10,7 +10,7 @@ #endif #include -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/generation_shared.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 2b8b26f0a06f..30bf3aa0a121 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/allocator.h" #include "core/framework/ort_value.h" #include "contrib_ops/cpu/utils/debug_macros.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 788eab1b672d..a107889afd76 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -26,7 +26,7 @@ #include "core/framework/session_options.h" #include "core/framework/TensorSeq.h" #include "core/framework/ort_value.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/greedy_search.h" #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/sequences.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 99c9474a2ca4..440a07e14a6c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/generation_shared.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index 83aa99ff4d50..d675ba742e03 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -7,7 +7,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_base.h" #include "contrib_ops/cpu/utils/dump_tensor.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 487a35c55a85..bde591626bb8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/allocator.h" #include "core/framework/feeds_fetches_manager.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 443d69d49470..34a1da99316a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cpu/utils/dump_tensor.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 4264ceff042f..9037e58aaf31 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/utils/dump_tensor.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 887a2b5cc519..51473c0c931b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index f3da01c952f5..bf866d67ffc0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_whisper_decoder.h" #include "contrib_ops/cpu/utils/dump_tensor.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc index 8480edc405e5..ff5f256e7bb7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" #include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fda7ac278412..56836bdda197 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/allocator.h" #include "contrib_ops/cpu/bert/attention_common.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index ceee17c2a2d0..ee49f362564a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -282,7 +282,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) - // calcualte k, v + // calculate k, v n = 2 * hidden_size; k = hidden_size; if (!has_layer_state_ || !use_past_) { diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index ae53eca541fa..8a17e945df3f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -141,7 +141,7 @@ __global__ void EmbedLayerNormKernel( } __syncthreads(); - // 2. load pos/segment/word embeddings and add them toghether + // 2. load pos/segment/word embeddings and add them together // offset into embeddings is given by word_id * hidden_size const int position_offset = position_id * hidden_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index e04cdf369c6a..90f0b94cafce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -107,7 +107,7 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; - // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates // local and causal, meaning when we have local window size params.is_causal = is_causal; if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu index d610051c77e5..2c251246267b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu @@ -148,7 +148,7 @@ __launch_bounds__(blockSize) // following the Softmax. // // For now zero-out only [row_index - 2*attention_window, row_index + 2*attention_window], - // we can even be more agressive and reduce the zeroing out window size since + // we can even be more aggressive and reduce the zeroing out window size since // each row has entries in 3 blocks (3*attention_window size instead of 4*attention_window) int zero_start = row_index - 2 * attention_window; if (zero_start < 0) { diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h index 2c9dc3689f88..116b9fb80da4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h @@ -41,7 +41,7 @@ struct Gmem_params { // Hidden dim per head int32_t d; - // array of length b+1 holding prefix sum of actual sequence lenghts. + // array of length b+1 holding prefix sum of actual sequence lengths. int32_t* cu_seqlens; }; @@ -69,7 +69,7 @@ struct Fused_multihead_attention_params_mhca { // See https://confluence.nvidia.com/pages/viewpage.action?pageId=302779721 for details. bool enable_i2f_trick; - // array of length b+1 holding prefix sum of actual sequence lenghts + // array of length b+1 holding prefix sum of actual sequence lengths int32_t* cu_seqlens; // use C/32 Format. diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index 20c936e1b671..e1549ad29566 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -4,7 +4,7 @@ #include "sharding_spec.h" #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/framework/tensor_shape.h" #include diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 5abc50a61c9a..5b47273bc8f2 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -37,7 +37,7 @@ class DeviceMesh { // corresponding sharding spec is a string "S[1]S[0]". // If that 2-D tensor's value is np.array([[5, 6], [7, 8]]), // GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization the sharding - // proccess. + // process. // - Start with a 2-D device mesh [[0, 1], [2, 3]] and // a 2-D tensor [[5, 6], [7, 8]] // - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]] diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc index 4b524dcf795a..65bec758ae52 100644 --- a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc @@ -87,7 +87,7 @@ Status FFTBase::DoFFT(OpKernelContext* context, const Tensor* X, bool complex int64_t batch_size = (batch_ndim == 0 ? 1 : input_shape.SizeToDimension(batch_ndim)); // infer output shape - // copy the input shape up to the second last dimention + // copy the input shape up to the second last dimension std::vector output_dims, signal_dims; int i = 0; for (; i < batch_ndim + signal_ndim_ - 1; ++i) { diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 28ab27ee33d1..07c5de2fe8d8 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm( #endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; #endif default: @@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm( compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; break; case CUDA_R_32F: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; default: ORT_THROW("Unable to determine computeType in operator GemmFloat8."); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 09d2dba7d203..e047bd948434 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -53,7 +53,7 @@ namespace GenerationCudaDeviceHelper { // e.g In the case of past(fp32) -> cast to fp16 -> Attention(fp16), the reorder // function will use the fp32 chunk size and cause the model silently generates // the incorrect results. -// TODO: Fix this issue. Either retrive the Attention op type from the graph or +// TODO: Fix this issue. Either retrieve the Attention op type from the graph or // check the type of past state as graph input should be same as Attention op type. // It might be better to forcefully require the same type since cast node generates // extra overhead. diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index 6b712ccfbeb7..0fe2d7ccb1f7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -7,7 +7,7 @@ #include "core/providers/cpu/tensor/utils.h" #include "core/providers/cuda/cuda_common.h" -#include "core/common/gsl.h" +#include #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/generation_shared.h" diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 92c780d4a9d4..7a16eb38181a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -360,7 +360,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { max_thr_per_blk)); // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. - // we should call fmha with total sequence lenghts + // we should call fmha with total sequence lengths seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); seqlens_k_ptr = seqlens_k_tmp.get(); diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu index 1e175b37b02d..b65841b35964 100644 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -78,7 +78,7 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { auto m = !transA_ ? a_shape[0] : a_shape[1]; auto k = !transA_ ? a_shape[1] : a_shape[0]; - ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatiable + ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible auto n = !transB_ ? b_shape[1] : b_shape[0]; TensorShapeVector output_shape = {m, n}; diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.h b/onnxruntime/core/codegen/mti/mti_tvm_utils.h index e85489dc1cac..c2a14106c168 100644 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.h +++ b/onnxruntime/core/codegen/mti/mti_tvm_utils.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include #include "core/codegen/mti/common.h" diff --git a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc b/onnxruntime/core/codegen/mti/tensor/concat_ops.cc index 625a91bc6456..3394d5b7e00a 100644 --- a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc +++ b/onnxruntime/core/codegen/mti/tensor/concat_ops.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/tensor/concat_ops.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/codegen/mti/tensor/gather.cc b/onnxruntime/core/codegen/mti/tensor/gather.cc index 3ea6ebf46620..152b3981f162 100644 --- a/onnxruntime/core/codegen/mti/tensor/gather.cc +++ b/onnxruntime/core/codegen/mti/tensor/gather.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/tensor/gather.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/codegen/mti/tensor/slice.cc b/onnxruntime/core/codegen/mti/tensor/slice.cc index 7c73be2b5234..6cbab43584d4 100644 --- a/onnxruntime/core/codegen/mti/tensor/slice.cc +++ b/onnxruntime/core/codegen/mti/tensor/slice.cc @@ -5,7 +5,7 @@ #include "core/codegen/mti/mti_tvm_utils.h" #include -#include "core/common/gsl.h" +#include #include #include diff --git a/onnxruntime/core/codegen/mti/tensor/split.cc b/onnxruntime/core/codegen/mti/tensor/split.cc index 8dbbd8fdcc28..6ee366314858 100644 --- a/onnxruntime/core/codegen/mti/tensor/split.cc +++ b/onnxruntime/core/codegen/mti/tensor/split.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/tensor/split.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/codegen/mti/tensor/tile.cc b/onnxruntime/core/codegen/mti/tensor/tile.cc index 60ef29f7ce70..2fef86adcbae 100644 --- a/onnxruntime/core/codegen/mti/tensor/tile.cc +++ b/onnxruntime/core/codegen/mti/tensor/tile.cc @@ -3,7 +3,7 @@ #include "core/codegen/mti/tensor/tile.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace tvm_codegen { diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc index 5c2557142dd0..88170bb56dd2 100644 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc +++ b/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace tvm_codegen { -// In the cell computation, we don't have the "direction" dimention and sequence dimension, +// In the cell computation, we don't have the "direction" dimension and sequence dimension, // which have been processed outside of the cell. // Here we implement an LTSM cell. // For those args (inputs/outputs) of hidden states we put AFTER regular args (inputs/outputs) diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc index 888dddfd3dbd..55892974aa33 100644 --- a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc +++ b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc @@ -7,7 +7,7 @@ #include "core/codegen/passes/utils/codegen_context.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/common/gsl.h" +#include #include diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 0bee36e4d10b..4c9e7e80db49 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -61,7 +61,7 @@ class CPUIDInfo { /** * @brief Some ARMv8 power efficient core has narrower 64b load/store - * that needs specialized optimiztion in kernels + * that needs specialized optimization in kernels * @return whether the indicated core has narrower load/store device */ bool IsCoreArmv8NarrowLd(uint32_t coreId) const { @@ -73,7 +73,7 @@ class CPUIDInfo { /** * @brief Some ARMv8 power efficient core has narrower 64b load/store - * that needs specialized optimiztion in kernels + * that needs specialized optimization in kernels * @return whether the current core has narrower load/store device */ bool IsCurrentCoreArmv8NarrowLd() const { diff --git a/onnxruntime/core/common/helper.cc b/onnxruntime/core/common/helper.cc index 7b7073634989..6a52db73df10 100644 --- a/onnxruntime/core/common/helper.cc +++ b/onnxruntime/core/common/helper.cc @@ -56,7 +56,7 @@ void PrintFinalMessage(const char* msg) { #else // TODO, consider changing the output of the error message from std::cerr to logging when the // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output - // might not be easily accesible on some systems such as mobile + // might not be easily accessible on some systems such as mobile // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS std::cerr << msg << std::endl; #endif diff --git a/onnxruntime/core/common/logging/capture.cc b/onnxruntime/core/common/logging/capture.cc index 3c23e15e5cc0..ac0d4d5fc707 100644 --- a/onnxruntime/core/common/logging/capture.cc +++ b/onnxruntime/core/common/logging/capture.cc @@ -3,7 +3,7 @@ #include "core/common/logging/capture.h" #include "core/common/logging/logging.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace logging { diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 10e117267e14..7b62de799b6f 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -636,7 +636,7 @@ bool ThreadPool::ShouldParallelize(const concurrency::ThreadPool* tp) { } int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) { - // When not using OpenMP, we parallelise over the N threads created by the pool + // When not using OpenMP, we parallelize over the N threads created by the pool // tp, plus 1 for the thread entering a loop. if (tp) { if (tp->force_hybrid_ || CPUIDInfo::GetCPUIDInfo().IsHybrid()) { diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc index 1eb3bbdb1237..42dff12eaa2d 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc @@ -4,7 +4,7 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/flatbuffers/schema/ort.fbs.h" #include "core/graph/constants.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/core/framework/data_transfer_utils.h b/onnxruntime/core/framework/data_transfer_utils.h index d54df49eeb9d..eeec329544bc 100644 --- a/onnxruntime/core/framework/data_transfer_utils.h +++ b/onnxruntime/core/framework/data_transfer_utils.h @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/tensor.h" diff --git a/onnxruntime/core/framework/endian_utils.h b/onnxruntime/core/framework/endian_utils.h index b83977c1ac67..6f084d058d00 100644 --- a/onnxruntime/core/framework/endian_utils.h +++ b/onnxruntime/core/framework/endian_utils.h @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/status.h" #include "core/common/common.h" diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index 7c8f91c7dad3..c5bcd22888b7 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup #include "core/graph/graph_viewer.h" diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h index 2b4d3ce81623..0dd17d2f4a62 100644 --- a/onnxruntime/core/framework/kernel_lookup.h +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup #include "core/framework/kernel_registry.h" #include "core/framework/kernel_type_str_resolver.h" diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 1868583f41ba..201fda6d978b 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -7,7 +7,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/status.h" #include "core/framework/kernel_type_str_resolver.h" diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h index fea2a6ef3a43..587be491b360 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver.h @@ -13,7 +13,7 @@ #include "core/graph/onnx_protobuf.h" #endif // !defined(ORT_MINIMAL_BUILD) -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/status.h" #include "core/graph/op_identifier.h" diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h index 3d06013e4fe7..5daab7c1159b 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h @@ -5,7 +5,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) -#include "core/common/gsl.h" +#include #include "core/common/status.h" #include "core/framework/kernel_type_str_resolver.h" #include "core/graph/op_identifier.h" diff --git a/onnxruntime/core/framework/op_node_proto_helper.cc b/onnxruntime/core/framework/op_node_proto_helper.cc index c3deb94300e7..ca9b74eafe4d 100644 --- a/onnxruntime/core/framework/op_node_proto_helper.cc +++ b/onnxruntime/core/framework/op_node_proto_helper.cc @@ -6,7 +6,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/graph/op.h" -#include "core/common/gsl.h" +#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 13da26d5e605..46bfc3630303 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -8,7 +8,7 @@ #include #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/framework/config_options.h" #include "core/framework/ort_value.h" diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index e318c9a8238c..b1a7504b283c 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -10,7 +10,7 @@ #include "core/common/flatbuffers.h" -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/inlined_containers.h" diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 77323f268a27..e8086877a915 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -11,7 +11,7 @@ #include #endif -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/common/span_utils.h" diff --git a/onnxruntime/core/framework/transpose_helper.h b/onnxruntime/core/framework/transpose_helper.h index c34d5ef3f27f..e33044117f89 100644 --- a/onnxruntime/core/framework/transpose_helper.h +++ b/onnxruntime/core/framework/transpose_helper.h @@ -37,7 +37,7 @@ We fall back to the default implementation in all other cases, and if the input #include "core/framework/tensor.h" #include "core/platform/threadpool.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { bool IsTransposeMovingSingleAxis(gsl::span permutations, size_t& from, size_t& to); diff --git a/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc b/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc index be39a2d7ec2b..b1b7cf346a27 100644 --- a/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc +++ b/onnxruntime/core/graph/contrib_ops/onnx_deprecated_operators.cc @@ -395,7 +395,7 @@ ONNX_CONTRIB_OPERATOR_SET_SCHEMA( const auto input_rank = input_shape.dim_size(); if (input_rank != 4) fail_shape_inference("Input's shape must be 4-D"); - // parse necessary attributes for futher processing + // parse necessary attributes for further processing std::vector border; bool border_present = getRepeatedAttribute(ctx, "border", border); if (!border_present || border.size() != 4) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 67451301023e..f73a50db7aaa 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -11,7 +11,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" @@ -98,7 +98,8 @@ static Status MergeShapeInfo(const std::string& output_name, #endif } ORT_CATCH(const ONNX_NAMESPACE::InferenceError& ex) { - // if this model was not created with the latest onnx version, allow the shape inferencing failure (strict == false). + // if this model was not created with the latest onnx version, allow the shape inferencing failure + // (strict == false). // we do this to have strict testing of the latest inferencing to detect bugs, but lenient shape inferencing for // older models in case later changes to the ONNX shape inferencing or ORT break them. if (!strict) { @@ -114,7 +115,8 @@ static Status MergeShapeInfo(const std::string& output_name, } #if !defined(DISABLE_OPTIONAL_TYPE) else if (utils::HasOptionalTensorType(source)) { - ONNX_NAMESPACE::UnionShapeInfo(utils::GetShape(source), *utils::GetMutableOptionalTypeProto(target)->mutable_tensor_type()); + ONNX_NAMESPACE::UnionShapeInfo(utils::GetShape(source), + *utils::GetMutableOptionalTypeProto(target)->mutable_tensor_type()); } #endif @@ -401,7 +403,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto& input_tensor_elem_type = input_tensor_type.elem_type(); const auto& current_tensor_elem_type = current_type.tensor_type().elem_type(); - ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types)); + ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, + override_types)); if (utils::HasShape(input_tensor_type)) { if (utils::HasShape(current_type)) { @@ -420,7 +423,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto input_tensor_elem_type = input_tensor_type.elem_type(); const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type(); - ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types)); + ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, + override_types)); if (utils::HasShape(input_tensor_type)) { if (utils::HasShape(current_type)) { @@ -440,7 +444,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu // Check for homogeneity within optional type if (is_input_type_optional_tensor_type != is_current_type_optional_tensor_type) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Optional Type mismatch. Expected: ", ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(current_type), + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Optional Type mismatch. Expected: ", + ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(current_type), " . Got: ", ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(input_type)); } @@ -453,7 +458,8 @@ common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& inpu const auto& input_tensor_elem_type = input_tensor_type.elem_type(); const auto& current_tensor_elem_type = optional_current_type.tensor_type().elem_type(); - ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types)); + ORT_RETURN_IF_ERROR(OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, + override_types)); if (utils::HasShape(optional_input_type.tensor_type())) { if (utils::HasShape(optional_current_type.tensor_type())) { @@ -1203,7 +1209,8 @@ Graph::Graph(const Model& owning_model, #if !defined(DISABLE_SPARSE_TENSORS) if (node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { auto p = sparse_tensor_names_.emplace(tensor->name()); - ORT_ENFORCE(p.second, "Duplicate constant node sparse initializer name: '", tensor->name(), "' Model is invalid."); + ORT_ENFORCE(p.second, "Duplicate constant node sparse initializer name: '", tensor->name(), + "' Model is invalid."); } #endif } @@ -1533,8 +1540,10 @@ void Graph::AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_ *dst_arg_pointer = src_arg; } - nodes_[src_node_index]->MutableRelationships().output_edges.insert(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot)); - nodes_[dst_node_index]->MutableRelationships().input_edges.insert(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot)); + nodes_[src_node_index]->MutableRelationships().output_edges.insert(Node::EdgeEnd(*nodes_[dst_node_index], + src_arg_slot, dst_arg_slot)); + nodes_[dst_node_index]->MutableRelationships().input_edges.insert(Node::EdgeEnd(*nodes_[src_node_index], + src_arg_slot, dst_arg_slot)); } void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_slot, int dst_arg_slot) { @@ -1573,13 +1582,15 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s ORT_THROW("Argument mismatch when removing edge."); } - nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], src_arg_slot, dst_arg_slot)); - nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], src_arg_slot, dst_arg_slot)); + nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], + src_arg_slot, dst_arg_slot)); + nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], + src_arg_slot, dst_arg_slot)); } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) -GSL_SUPPRESS(es.84) // ignoring return value from unordered_map::insert causes noisy complaint +GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint Status Graph::BuildConnections(std::unordered_set& outer_scope_node_args_consumed) { // recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs if (!resolve_context_.nodes_with_subgraphs.empty()) { @@ -2356,7 +2367,7 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, #endif // ENABLE_TRAINING -GSL_SUPPRESS(es.84) // noisy warning about ignoring return value from insert(...) +GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...) Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { nodes_in_topological_order_.clear(); std::unordered_set downstream_nodes; // nodes downstream of the node we're currently checking @@ -2429,7 +2440,8 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { const NodeIndex idx = iter->Index(); // the input to this node is also downstream of this node if (downstream_nodes.find(idx) != downstream_nodes.end()) { - Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, "This is an invalid model. Error: the graph is not acyclic."); + Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, + "This is an invalid model. Error: the graph is not acyclic."); return status; } @@ -2444,7 +2456,8 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { return Status::OK(); } - return Status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, "This is an invalid model. Error: the graph is not acyclic."); + return Status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, + "This is an invalid model. Error: the graph is not acyclic."); } bool FullyDefinedType(const TypeProto& type_proto) { @@ -2487,11 +2500,13 @@ bool FullyDefinedType(const TypeProto& type_proto) { // parameters are the Graph instance for the subgraph, the input types from the control flow node that contains // the subgraph, and the vector to write the output from the inferencing. using SubgraphInferencingFunc = - std::function&, std::vector&, const Graph::ResolveOptions&)>; + std::function&, std::vector&, + const Graph::ResolveOptions&)>; class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer { public: - GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, const Graph::ResolveOptions& options) + GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func, + const Graph::ResolveOptions& options) : node_(node), graph_(graph), inferencing_func_(inferencing_func), options_(options) { } @@ -2720,7 +2735,7 @@ Status Graph::UpdateShapeInference(Node& node) { } // Implementation of type-inference and type-checking for a single node -GSL_SUPPRESS(f.23) // spurious warning about inferred_type never being checked for null +GSL_SUPPRESS(f .23) // spurious warning about inferred_type never being checked for null Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const ResolveOptions& options) { auto& node_name = node.Name(); @@ -2796,7 +2811,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // The type-parameter T is bound to different values for different inputs. Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, "Type Error: Type parameter (" + op_formal_parameter.GetTypeStr() + - ") of Optype (" + op.Name() + ") bound to different types (" + *(param_to_type_iter->second) + + ") of Optype (" + op.Name() + ") bound to different types (" + + *(param_to_type_iter->second) + " and " + *(input_def->Type()) + " in node (" + node_name + ")."); return status; @@ -2930,7 +2946,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso *merge_target.mutable_sparse_tensor_type()->mutable_shape() = *output_def->Shape(); } #endif - auto status = MergeShapeInfo(output_def->Name(), onnx_inferred_type, merge_target, strict_shape_type_inference_, logger_); + auto status = MergeShapeInfo(output_def->Name(), onnx_inferred_type, merge_target, + strict_shape_type_inference_, logger_); if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node_name, " ", status.ErrorMessage()); } @@ -3176,7 +3193,7 @@ Status Graph::VerifyInputAndInitializerNames() { } for (auto& initializer_pair : name_to_initial_tensor_) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) inputs_and_initializers.insert(initializer_pair.first); // Initializers are expected to be included in inputs (according to ONNX spec). // onnxruntime relaxes this constraint. No duplicate-name check here. @@ -3412,7 +3429,8 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) - ORT_ENFORCE(sparse_tensor_names_.count(tensor_name) == 0, "sparse_tensor_names_ not in sync with name_to_initial_tensor_"); + ORT_ENFORCE(sparse_tensor_names_.count(tensor_name) == 0, + "sparse_tensor_names_ not in sync with name_to_initial_tensor_"); #endif } @@ -3448,7 +3466,8 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi return true; }; - ORT_RETURN_IF_NOT(!is_external || utils::HasExternalData(old_initializer), "Trying to replace non-external initializer with external data"); + ORT_RETURN_IF_NOT(!is_external || utils::HasExternalData(old_initializer), + "Trying to replace non-external initializer with external data"); ORT_RETURN_IF_NOT(dims_eq(), "Replacement tensor's dimensions do not match."); ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(), @@ -3526,7 +3545,8 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( ORT_ENFORCE(existing_entry != mutable_initializers.pointer_end(), "graph_proto_ is not in sync with name_to_initial_tensor_"); (**existing_entry).clear_data_location(); - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); + const DataTypeImpl* const type = + DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); auto tensor = Tensor(type, tensor_shape, tensor_buffer, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); @@ -4246,7 +4266,7 @@ void Graph::ComputeOverridableInitializers() { #if !defined(ORT_MINIMAL_BUILD) -GSL_SUPPRESS(es.84) // warning about ignoring return value from insert(...) +GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...) Status Graph::SetGraphInputsOutputs() { // If loaded from a model file, we start from the specified inputs and // outputs set earlier by InitializeStateFromModelFileGraphProto(). @@ -4432,7 +4452,7 @@ Status Graph::PopulateNodeArgToProducerConsumerLookupsFromNodes() { } // calling private ctor -GSL_SUPPRESS(r.11) +GSL_SUPPRESS(r .11) gsl::not_null Graph::AllocateNode() { ORT_ENFORCE(nodes_.size() < static_cast(std::numeric_limits::max())); std::unique_ptr new_node(new Node(nodes_.size(), *this)); @@ -4445,7 +4465,7 @@ gsl::not_null Graph::AllocateNode() { return gsl::not_null{node}; } -// TODO: Does this need (and maybe AllocateNode) to be threadsafe so nodes_ and num_of_nodes_ managed more carefully? +// TODO(s): Does this need (and maybe AllocateNode) to be threadsafe so nodes_ and num_of_nodes_ managed more carefully? bool Graph::ReleaseNode(NodeIndex index) { if (index >= nodes_.size()) { return false; @@ -4562,13 +4582,13 @@ void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_n auto dst_idx = input_edge.GetDstArgIndex(); // if this input is an input of the fused node add an edge for that - if (dst_idx < (int)node->InputDefs().size()) { + if (dst_idx < static_cast(node->InputDefs().size())) { auto it = input_indexes.find(node->InputDefs()[dst_idx]->Name()); if (it != input_indexes.cend()) { AddEdge(producer_idx, new_node_idx, src_idx, it->second); } } else { - int dst_implicit_input_idx = dst_idx - (int)node->InputDefs().size(); + int dst_implicit_input_idx = dst_idx - static_cast(node->InputDefs().size()); ORT_ENFORCE(dst_implicit_input_idx < (int)node->ImplicitInputDefs().size()); auto it = input_indexes.find(node->ImplicitInputDefs()[dst_implicit_input_idx]->Name()); if (it != input_indexes.cend()) { @@ -5012,7 +5032,8 @@ Status Graph::InlineFunction(Node& callnode) { // Remove output edges. Requirement for RemoveNode() below. auto output_edges = callnode.GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator for (const auto& output_edge : output_edges) { - RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); + RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), + output_edge.GetDstArgIndex()); } // create a uniq_identifier to append to every node name and intermediate input\outputs diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 13620f4d8b3b..221bc01f5d15 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -2,12 +2,14 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" - -#include - #include "core/graph/graph.h" #include "core/common/logging/logging.h" +#include +#include +#include +#include + namespace onnxruntime { namespace graph_utils { @@ -56,7 +58,8 @@ static bool CanUpdateImplicitInputNameInSubgraph(const Node& node, } for (auto& subgraph_node : subgraph->Nodes()) { - // recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels of nested + // recurse if this node also consumes removed_output_name as an implicit input (i.e. there are multiple levels + // of nested // subgraphs, and at least one level lower uses removed_output_name as an implicit input const auto subgraph_node_implicit_inputs = subgraph_node.ImplicitInputDefs(); if (!subgraph_node_implicit_inputs.empty()) { @@ -464,13 +467,13 @@ static bool IsOnlyOneOutputUsed(const Graph& graph, const Node& node, const std: // a) there's only 1, and b) it's the same as any output consumed by another node auto output_indexes = graph.GetNodeOutputsInGraphOutputs(node); auto num_graph_outputs = output_indexes.size(); - if (num_graph_outputs > 1) + if (num_graph_outputs > 1) { return false; - else if (num_graph_outputs == 1) { - if (first_output != unassigned) + } else if (num_graph_outputs == 1) { + if (first_output != unassigned) { // an output is consumed by other nodes, so make sure the same output is providing the graph output return output_indexes.front() == first_output; - else { + } else { // graph output only as no other nodes are consuming the output, so just update the output_name output_name = &node.OutputDefs()[output_indexes.front()]->Name(); } @@ -678,7 +681,8 @@ const Node* FirstParentByType(const Node& node, const std::string& parent_type) return nullptr; } -void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx) { +void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, + int replacement_output_idx) { // get the output edges from node for output_idx std::vector output_edges = GraphEdge::GetNodeOutputEdges(node, output_idx); @@ -726,7 +730,9 @@ void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input) { "Can only add a new input at the end of the current ones."); target.MutableInputDefs().push_back(&new_input); - assert(target.MutableInputArgsCount().size() > static_cast(target_input_idx)); // expect existing entry for all possible inputs + + // expect existing entry for all possible inputs + assert(target.MutableInputArgsCount().size() > static_cast(target_input_idx)); target.MutableInputArgsCount()[target_input_idx] = 1; } @@ -798,7 +804,8 @@ bool FindPath(const Node& node, bool is_input_edge, gsl::spanOpType() << "->" << edge.op_type; + LOGS(logger, WARNING) << "Failed since multiple edges matched:" << current_node->OpType() << "->" + << edge.op_type; return false; } edge_found = &(*it); @@ -821,7 +828,8 @@ bool FindPath(const Node& node, bool is_input_edge, gsl::span edges_to_match, std::vector>& result, const logging::Logger& logger) { +bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span edges_to_match, + std::vector>& result, const logging::Logger& logger) { result.clear(); std::vector edge_ends; @@ -830,9 +838,10 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span Node& { - return *graph.GetNode(edge_end->GetNode().Index()); - }); + std::transform(edge_ends.begin(), edge_ends.end(), std::back_inserter(result), + [&graph](const Node::EdgeEnd* edge_end) -> Node& { + return *graph.GetNode(edge_end->GetNode().Index()); + }); return true; } diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 319e055200cc..0b713196203d 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -8,6 +8,9 @@ #include "core/graph/onnx_protobuf.h" #include "core/graph/graph.h" +#include +#include + namespace onnxruntime { namespace graph_utils { @@ -247,7 +250,8 @@ e.g. Node A produces outputs A1 and A2. to replace B1 (output index 0 for node B) with A2 (output index 1 for node A) as input to the downstream node C. The edge that existed between B and C for B1 will be removed, and replaced with an edge between A and C for A2. */ -void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, int replacement_output_idx); +void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node& replacement, + int replacement_output_idx); /** Replace the input to a node with a NodeArg. @remarks The replacement only updates the node's input definition and does not create any edges, @@ -302,11 +306,13 @@ inline void FinalizeNodeFusion(Graph& graph, The output definitions and edges from the last node in 'nodes' will be moved to replacement_node. All nodes in 'nodes' will be removed. */ -inline void FinalizeNodeFusion(Graph& graph, gsl::span> nodes, Node& replacement_node) { +inline void FinalizeNodeFusion(Graph& graph, gsl::span> nodes, + Node& replacement_node) { FinalizeNodeFusion(graph, nodes, replacement_node, replacement_node); } -inline void FinalizeNodeFusion(Graph& graph, std::initializer_list> nodes, Node& replacement_node) { +inline void FinalizeNodeFusion(Graph& graph, std::initializer_list> nodes, + Node& replacement_node) { FinalizeNodeFusion(graph, AsSpan(nodes), replacement_node, replacement_node); } @@ -357,17 +363,23 @@ struct EdgeEndToMatch { It is recommended to match path from bottom to top direction to avoid such issue. It is because each node input (dst_arg_index) only accepts one input edge. */ -bool FindPath(const Node& node, bool is_input_edge, gsl::span edges_to_match, std::vector& result, const logging::Logger& logger); +bool FindPath(const Node& node, bool is_input_edge, gsl::span edges_to_match, + std::vector& result, const logging::Logger& logger); -inline bool FindPath(const Node& node, bool is_input_edge, std::initializer_list edges_to_match, std::vector& result, const logging::Logger& logger) { +inline bool FindPath(const Node& node, bool is_input_edge, std::initializer_list edges_to_match, + std::vector& result, const logging::Logger& logger) { return FindPath(node, is_input_edge, AsSpan(edges_to_match), result, logger); } /** Same as FindPath above, but return the references of matched Node */ -bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span edges_to_match, std::vector>& result, const logging::Logger& logger); +bool FindPath(Graph& graph, const Node& node, bool is_input_edge, gsl::span edges_to_match, + std::vector>& result, const logging::Logger& logger); -inline bool FindPath(Graph& graph, const Node& node, bool is_input_edge, std::initializer_list edges_to_match, std::vector>& result, const logging::Logger& logger) { +inline bool FindPath(Graph& graph, const Node& node, bool is_input_edge, + std::initializer_list edges_to_match, + std::vector>& result, + const logging::Logger& logger) { return FindPath(graph, node, is_input_edge, AsSpan(edges_to_match), result, logger); } diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index c639eeac5ea4..1842c2b4a0d1 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -332,7 +332,7 @@ const std::vector& GraphViewer::GetNodesInTopologicalOrder(ExecutionO } const std::vector& GraphViewer::GetRootNodes() const { - // TODO: See if we need to calculate the root_nodes_ of the filtered graph. + // TODO(somebody): See if we need to calculate the root_nodes_ of the filtered graph. // GetRootNodes is only used by parallel executor currently, and isn't relevant to the usage of a filtered graph. ORT_ENFORCE(filter_info_ == nullptr, "Not supported with filtered graph."); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index e9d1b4e944ed..ee4d9f915497 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -20,7 +20,7 @@ #endif #include "core/util/protobuf_parsing_utils.h" -#include "core/common/gsl.h" +#include #include "core/platform/env.h" @@ -140,12 +140,13 @@ Model::Model(const std::string& graph_name, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), + func.name()), std::move(func_template_ptr)); } // need to call private ctor so can't use make_shared - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) graph_.reset(new Graph(*this, model_proto_.mutable_graph(), *p_domain_to_version, IrVersion(), schema_registry, logger, options.strict_shape_type_inference)); } @@ -269,11 +270,13 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr)); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), + func.name()), + std::move(func_template_ptr)); } // create instance. need to call private ctor so can't use make_unique - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) graph_.reset(new Graph(*this, model_proto_.mutable_graph(), domain_to_version, IrVersion(), schema_registry, logger, options.strict_shape_type_inference)); } @@ -425,7 +428,7 @@ Status Model::Load(const ModelProto& model_proto, } // need to call private ctor so can't use make_shared - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) auto status = Status::OK(); ORT_TRY { @@ -465,7 +468,7 @@ Status Model::Load(ModelProto&& model_proto, } // need to call private ctor so can't use make_shared - GSL_SUPPRESS(r.11) + GSL_SUPPRESS(r .11) auto status = Status::OK(); ORT_TRY { model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options); @@ -512,7 +515,7 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { } if (!status.IsOK()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); return status; } @@ -554,7 +557,8 @@ static Status SaveModel(Model& model, const T& file_path) { const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); - if (typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string') { + if (typeof process == 'object' && typeof process.versions == 'object' && + typeof process.versions.node == 'string') { // Node.js require('fs').writeFileSync(file_path, bytes); } else { @@ -585,7 +589,7 @@ static Status SaveModel(Model& model, const T& file_path) { }); } if (!status.IsOK()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); return status; } @@ -616,7 +620,7 @@ static Status SaveModelWithExternalInitializers(Model& model, }); } if (!status.IsOK()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); return status; } @@ -628,8 +632,8 @@ Status Model::Load(const PathString& file_path, return LoadModel(file_path, model_proto); } -GSL_SUPPRESS(r.30) // spurious warnings. p_model is potentially reset in the internal call to Load -GSL_SUPPRESS(r.35) +GSL_SUPPRESS(r .30) // spurious warnings. p_model is potentially reset in the internal call to Load +GSL_SUPPRESS(r .35) Status Model::Load(const PathString& file_path, std::shared_ptr& p_model, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger, const ModelOptions& options) { @@ -762,7 +766,8 @@ Status Model::SaveWithExternalInitializers(Model& model, ORT_RETURN_IF_ERROR(model.MainGraph().Resolve()); - auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold); + auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, + initializer_size_threshold); google::protobuf::io::FileOutputStream output(fd); const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); if (result) { diff --git a/onnxruntime/core/graph/node_attr_utils.h b/onnxruntime/core/graph/node_attr_utils.h index 9433cfabc974..638cebe6a320 100644 --- a/onnxruntime/core/graph/node_attr_utils.h +++ b/onnxruntime/core/graph/node_attr_utils.h @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/graph/onnx_protobuf.h" #include "core/graph/basic_types.h" diff --git a/onnxruntime/core/graph/op.h b/onnxruntime/core/graph/op.h index 2a720a3ccad4..295b8cb95f48 100644 --- a/onnxruntime/core/graph/op.h +++ b/onnxruntime/core/graph/op.h @@ -3,12 +3,13 @@ #pragma once -#include -#include #include "core/graph/onnx_protobuf.h" #include "core/common/status.h" #include "core/graph/constants.h" +#include +#include + namespace onnxruntime { using AttrType = ONNX_NAMESPACE::AttributeProto_AttributeType; using NodeAttributes = std::unordered_map; @@ -29,19 +30,18 @@ AttributeProto_AttributeType_GRAPHS = 10, AttributeProto_AttributeType_SPARSE_TENSOR = 22, AttributeProto_AttributeType_SPARSE_TENSORS = 23, */ -static constexpr const char* kAttrTypeStrings[] = - { - "UNDEFINED", - "FLOAT", - "INT", - "STRING", - "TENSOR", - "GRAPH", - "FLOATS", - "INTS", - "STRINGS", - "TENSORS", - "GRAPHS"}; +static constexpr const char* kAttrTypeStrings[] = { + "UNDEFINED", + "FLOAT", + "INT", + "STRING", + "TENSOR", + "GRAPH", + "FLOATS", + "INTS", + "STRINGS", + "TENSORS", + "GRAPHS"}; class TypeUtils { public: diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.cc b/onnxruntime/core/graph/runtime_optimization_record_container.cc index acd85b909e5b..2d0e1076ee37 100644 --- a/onnxruntime/core/graph/runtime_optimization_record_container.cc +++ b/onnxruntime/core/graph/runtime_optimization_record_container.cc @@ -4,15 +4,13 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "core/graph/runtime_optimization_record_container.h" - -#include - -#include "core/common/gsl.h" - #include "core/flatbuffers/flatbuffers_utils.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/graph/op_identifier_utils.h" +#include +#include + namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) @@ -182,7 +180,8 @@ Status RuntimeOptimizationRecordContainer::LoadFromOrtFormat( } ORT_RETURN_IF_NOT(optimizer_name_to_records.emplace(optimizer_name, std::move(records)).second, - "Attempting to load runtime optimization records for a previously loaded optimizer: ", optimizer_name); + "Attempting to load runtime optimization records for a previously loaded optimizer: ", + optimizer_name); } optimizer_name_to_records_ = std::move(optimizer_name_to_records); diff --git a/onnxruntime/core/graph/schema_registry.cc b/onnxruntime/core/graph/schema_registry.cc index 4dc714bd8af7..a7d94f4571d9 100644 --- a/onnxruntime/core/graph/schema_registry.cc +++ b/onnxruntime/core/graph/schema_registry.cc @@ -99,7 +99,7 @@ common::Status OnnxRuntimeOpSchemaRegistry::RegisterOpSchemaInternal(ONNX_NAMESP << "than the operator set version " << ver_range_it->second.opset_version << std::endl; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostream.str()); } - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) map_[op_name][op_domain].emplace(std::make_pair(ver, op_schema)); return common::Status::OK(); } @@ -161,7 +161,8 @@ void SchemaRegistryManager::RegisterRegistry(std::shared_ptrGetLatestOpsetVersions(is_onnx_only); @@ -172,7 +173,7 @@ void SchemaRegistryManager::GetDomainToVersionMapForRegistries(DomainToVersionMa // If the map doesn't yet contain this domain, insert it with this registry's value. // Otherwise, merge the existing range in the map. if (iter == domain_version_map.end()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) domain_version_map.insert(local_domain); } else { iter->second = std::max(iter->second, local_domain.second); @@ -194,7 +195,7 @@ DomainToVersionMap SchemaRegistryManager::GetLastReleasedOpsetVersions(bool is_o continue; auto it = domain_version_map.find(domain.first); if (it == domain_version_map.end()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) domain_version_map.insert(std::make_pair(domain.first, domain.second)); } else { it->second = std::max(it->second, domain.second); @@ -217,7 +218,7 @@ DomainToVersionMap SchemaRegistryManager::GetLatestOpsetVersions(bool is_onnx_on continue; auto it = domain_version_map.find(domain.first); if (it == domain_version_map.end()) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) domain_version_map.insert(std::make_pair(domain.first, domain.second.second)); } else { it->second = std::max(it->second, domain.second.second); @@ -271,7 +272,7 @@ void SchemaRegistryManager::GetSchemaAndHistory( } if (new_version < version) { - GSL_SUPPRESS(es.84) + GSL_SUPPRESS(es .84) unchecked_registry_indices.insert(unchecked_registry_indices.end(), checked_registry_indices.begin(), checked_registry_indices.end()); diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h index 38795291b032..9c7e00fff88e 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -35,7 +35,7 @@ * * @file quantb_gemm.h * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. -*/ + */ #pragma once @@ -145,7 +145,6 @@ template < typename PermuteDLayout = layout::NoPermute> class QuantBGemm { public: - using ElementA = ElementA_; using LayoutA = LayoutA_; using TensorRefA = TensorRef; @@ -189,34 +188,33 @@ class QuantBGemm { /// Define the kernel using GemmKernel = typename kernel::DefaultQuantBGemm< - ElementA, - LayoutA, - kAlignmentA, - ElementB, - LayoutB, - kAlignmentB, - ElementQScale, - ElementQOffset, - LayoutQMeta, - QuantBlocking, - ElementC, - LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - kStages, - kSplitKSerial, - Operator, - GatherA, - GatherB, - ScatterD, - PermuteDLayout - >::GemmKernel; + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout>::GemmKernel; /// Argument structure struct Arguments { @@ -237,9 +235,9 @@ class QuantBGemm { // split-K parallelism (etc.) are not yet supported, keeping this for future extension int split_k_slices{1}; // For gather+scatter operations - int const *gather_A_indices{nullptr}; - int const *gather_B_indices{nullptr}; - int const *scatter_D_indices{nullptr}; + int const* gather_A_indices{nullptr}; + int const* gather_B_indices{nullptr}; + int const* scatter_D_indices{nullptr}; // // Methods @@ -247,49 +245,47 @@ class QuantBGemm { /// Default ctor CUTLASS_HOST_DEVICE - Arguments(): problem_size(0, 0, 0) {} + Arguments() : problem_size(0, 0, 0) {} /// Constructs an Arguments structure CUTLASS_HOST_DEVICE Arguments( - GemmCoord problem_size_, - TensorRef ref_A_, - TensorRef ref_B_, - TensorRef ref_Qscale_, - TensorRef ref_C_, - TensorRef ref_D_, - typename EpilogueOutputOp::Params epilogue_ = - typename EpilogueOutputOp::Params()): - problem_size(problem_size_), - ref_A(ref_A_), - ref_B(ref_B_), - ref_Qscale(ref_Qscale_), - ref_C(ref_C_), - ref_D(ref_D_), - epilogue(epilogue_) { - assert(!kHasQOffset); + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()) : problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); } CUTLASS_HOST_DEVICE Arguments( - GemmCoord problem_size_, - TensorRef ref_A_, - TensorRef ref_B_, - TensorRef ref_Qscale_, - TensorRef ref_Qoffset_, - TensorRef ref_C_, - TensorRef ref_D_, - typename EpilogueOutputOp::Params epilogue_ = - typename EpilogueOutputOp::Params()): - problem_size(problem_size_), - ref_A(ref_A_), - ref_B(ref_B_), - ref_Qscale(ref_Qscale_), - ref_Qoffset(ref_Qoffset_), - ref_C(ref_C_), - ref_D(ref_D_), - epilogue(epilogue_) { - assert(kHasQOffset); + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()) : problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); } }; @@ -299,24 +295,22 @@ class QuantBGemm { public: /// Constructs the GEMM. - QuantBGemm() { } + QuantBGemm() {} /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const &args) { - + static Status can_implement(Arguments const& args) { if (!kSplitKSerial && args.split_k_slices > 1) { return Status::kErrorInvalidProblem; } Status status = GemmKernel::can_implement( - args.problem_size, - args.ref_A.non_const_ref(), - args.ref_B.non_const_ref(), - args.ref_Qscale.non_const_ref(), - args.ref_Qoffset.non_const_ref(), - args.ref_C.non_const_ref(), - args.ref_D - ); + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D); if (status != Status::kSuccess) { return status; @@ -326,20 +320,18 @@ class QuantBGemm { } /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - + static size_t get_workspace_size(Arguments const& args) { size_t bytes = 0; // Determine grid shape ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.split_k_slices); + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); if (kSplitKSerial && args.split_k_slices > 1) { - bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); } @@ -347,15 +339,14 @@ class QuantBGemm { } /// Initializes GEMM state from arguments. - Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { // Determine grid shape ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.split_k_slices); + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); if (kSplitKSerial) { if (args.split_k_slices > 1) { @@ -372,7 +363,6 @@ class QuantBGemm { } } } else { - if (args.split_k_slices > 1) { return Status::kErrorInvalidProblem; } @@ -380,27 +370,25 @@ class QuantBGemm { // Initialize the Params structure params_ = typename GemmKernel::Params{ - args.problem_size, - grid_shape, - args.ref_A.non_const_ref(), - args.ref_B.non_const_ref(), - args.ref_Qscale.non_const_ref(), - args.ref_Qoffset.non_const_ref(), - args.ref_C.non_const_ref(), - args.ref_D, - args.epilogue, - static_cast(workspace), - args.gather_A_indices, - args.gather_B_indices, - args.scatter_D_indices - }; + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices}; return Status::kSuccess; } /// Lightweight update given a subset of arguments - Status update(Arguments const &args, void *workspace = nullptr) { - + Status update(Arguments const& args, void* workspace = nullptr) { if (kSplitKSerial && args.split_k_slices > 1) { if (!workspace) { return Status::kErrorWorkspaceNull; @@ -414,14 +402,13 @@ class QuantBGemm { params_.ref_C.reset(args.ref_C.non_const_ref().data()); params_.ref_D.reset(args.ref_D.data()); params_.output_op = args.epilogue; - params_.semaphore = static_cast(workspace); + params_.semaphore = static_cast(workspace); return Status::kSuccess; } /// Runs the kernel using initialized state. Status run(cudaStream_t stream = nullptr) { - ThreadblockSwizzle threadblock_swizzle; dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); @@ -457,10 +444,9 @@ class QuantBGemm { /// Runs the kernel using initialized state. Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { @@ -471,11 +457,10 @@ class QuantBGemm { } }; - //////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h index 2f4460bb59e9..f471411730a3 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -69,7 +69,7 @@ #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -#endif //CUTLASS_ARCH_WMMA_ENABLED +#endif // CUTLASS_ARCH_WMMA_ENABLED //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -139,13 +139,11 @@ template < /// Permute operand B typename PermuteBLayout = layout::NoPermute, /// - typename Enable = void -> + typename Enable = void> struct DefaultQuantBGemm; //////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Ampere Architecture @@ -204,8 +202,7 @@ template < /// Permute operand A typename PermuteALayout, /// Permute operand B - typename PermuteBLayout -> + typename PermuteBLayout> struct DefaultQuantBGemm { - - static_assert((platform::is_same::value - || platform::is_same>::value), - "Epilogue in the kernel level must be row major"); + static_assert((platform::is_same::value || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h index 6e5ad8f40614..72b2cf641e13 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -59,13 +59,12 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -> + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. + > struct QuantBGemm { - using Mma = Mma_; using Epilogue = Epilogue_; using OutputOp = typename Epilogue::OutputOp; @@ -96,55 +95,53 @@ struct QuantBGemm { typename Epilogue::OutputTileIterator::Params params_D; typename Epilogue::OutputTileIterator::TensorRef ref_D; typename OutputOp::Params output_op; - int *semaphore; + int* semaphore; int gemm_k_size; // how many k vectors are processed by this threadblock // For gather+scatter operations - int const *gather_A_indices; - int const *gather_B_indices; - int const *scatter_D_indices; + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; // // Methods // CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} CUTLASS_HOST_DEVICE Params( - cutlass::gemm::GemmCoord const & problem_size, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorQScale::TensorRef ref_QScale, - typename Mma::IteratorQOffset::TensorRef ref_QOffset, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, - typename OutputOp::Params output_op = typename OutputOp::Params(), - int *workspace = nullptr, - int const *gather_A_indices = nullptr, - int const *gather_B_indices = nullptr, - int const *scatter_D_indices = nullptr - ): - problem_size(problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(ref_A.layout()), - ref_A(ref_A), - params_B(ref_B.layout()), - ref_B(ref_B), - params_QScale(ref_QScale.layout()), - ref_QScale(ref_QScale), - params_QOffset(ref_QOffset.layout()), - ref_QOffset(ref_QOffset), - params_C(ref_C.layout()), - ref_C(ref_C), - params_D(ref_D.layout()), - ref_D(ref_D), - output_op(output_op), - gather_A_indices(gather_A_indices), - gather_B_indices(gather_B_indices), - scatter_D_indices(scatter_D_indices) { + cutlass::gemm::GemmCoord const& problem_size, + cutlass::gemm::GemmCoord const& grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int* workspace = nullptr, + int const* gather_A_indices = nullptr, + int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) : problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); @@ -165,42 +162,41 @@ struct QuantBGemm { // CUTLASS_HOST_DEVICE - QuantBGemm() { } + QuantBGemm() {} /// Determines whether kernel satisfies alignment CUTLASS_HOST_DEVICE static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorQScale::TensorRef ref_QScale, - typename Mma::IteratorQOffset::TensorRef ref_QOffset, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D) { - + cutlass::gemm::GemmCoord const& problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { // TODO check problem_size K, N must be multiple of QuantBlocking static int const kAlignmentA = (platform::is_same>::value) - ? 32 + ? 32 : (platform::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = (platform::is_same>::value) - ? 32 + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 : (platform::is_same>::value) - ? 64 - : Mma::IteratorB::AccessType::kElements; + ? 64 + : Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = (platform::is_same>::value) - ? 32 + ? 32 : (platform::is_same>::value) - ? 64 - : Epilogue::OutputTileIterator::kElementsPerAccess; + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; if (!TensorRef_aligned(ref_A, kAlignmentA)) { return Status::kErrorMisalignedOperand; @@ -237,7 +233,7 @@ struct QuantBGemm { return Status::kErrorMisalignedOperand; } - if constexpr(kHasQOffset) { + if constexpr (kHasQOffset) { if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { return Status::kErrorMisalignedOperand; } @@ -256,8 +252,7 @@ struct QuantBGemm { /// Executes one GEMM CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - + void operator()(Params const& params, SharedStorage& shared_storage) { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; @@ -266,26 +261,24 @@ struct QuantBGemm { // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { return; } // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, }; cutlass::MatrixCoord tb_offset_B{ - (threadblock_tile_offset.k() * params.gemm_k_size) / 2, - (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 - }; + (threadblock_tile_offset.k() * params.gemm_k_size) / 2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2}; // Problem size is a function of threadblock index in the K dimension int problem_size_k = min( - params.problem_size.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; @@ -295,20 +288,20 @@ struct QuantBGemm { // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( - params.params_A, - params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A, - params.gather_A_indices); + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); typename Mma::IteratorB iterator_B( - params.params_B, - params.ref_B.data(), - {problem_size_k/2, params.problem_size.n()/2}, - thread_idx, - tb_offset_B, - params.gather_B_indices); + params.params_B, + params.ref_B.data(), + {problem_size_k / 2, params.problem_size.n() / 2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; @@ -318,24 +311,23 @@ struct QuantBGemm { assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); cutlass::MatrixCoord tb_offset_QScale{ - threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), - threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) - }; + threadblock_tile_offset.k() * (params.gemm_k_size / Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN / Mma::QuantBlocking::kColumn)}; typename Mma::IteratorQScale iterator_QScale( - params.params_QScale, - params.ref_QScale.data(), - {qscale_k, qscale_n}, - thread_idx, - tb_offset_QScale, - nullptr); + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); typename Mma::IteratorQOffset iterator_QOffset( - params.params_QOffset, - params.ref_QOffset.data(), - {qscale_k, qscale_n}, - thread_idx, - tb_offset_QScale); + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -371,11 +363,10 @@ struct QuantBGemm { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - //assume identity swizzle + // assume identity swizzle MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); @@ -384,7 +375,6 @@ struct QuantBGemm { // If performing a reduction via split-K, fetch the initial synchronization if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - // Fetch the synchronization lock initially but do not block. semaphore.fetch(); @@ -394,40 +384,36 @@ struct QuantBGemm { // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - params.ref_C.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices - ); + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - params.ref_D.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices - ); + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); // Wait on the semaphore - this latency may have been covered by iterator construction if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. if (threadblock_tile_offset.k()) { iterator_C = iterator_D; } semaphore.wait(threadblock_tile_offset.k()); - } // Execute the epilogue operator to update the destination tensor. @@ -438,14 +424,11 @@ struct QuantBGemm { // if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - // The final threadblock resets the semaphore for subsequent grids. lock = 0; - } - else { + } else { // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } @@ -457,6 +440,6 @@ struct QuantBGemm { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h index 0af604f090e1..02b439fbf59e 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -113,8 +113,7 @@ template < /// Permute operand A typename PermuteALayout = layout::NoPermute, /// Permute operand B - typename PermuteBLayout = layout::NoPermute - > + typename PermuteBLayout = layout::NoPermute> struct DefaultQuantBMma; //////////////////////////////////////////////////////////////////////////////// @@ -164,19 +163,16 @@ template < /// Permute operand A typename PermuteALayout, /// Permute operand B - typename PermuteBLayout - > + typename PermuteBLayout> struct DefaultQuantBMma { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "simt epilogue must be row major"); + kAlignmentB, ElementQScale, ElementQOffset, + LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator, false, + GatherA, GatherB, PermuteALayout, PermuteBLayout> { + static_assert(platform::is_same::value || platform::is_same>::value, + "simt epilogue must be row major"); static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) @@ -208,7 +204,7 @@ struct DefaultQuantBMma; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; // Define iterators over tiles from the quant scales @@ -225,8 +221,8 @@ struct DefaultQuantBMma; using IteratorQOffset = cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< - typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, - 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< @@ -241,8 +237,8 @@ struct DefaultQuantBMma::value || is_complex::value) -> + bool IsComplex = false // (is_complex::value || is_complex::value) + > struct DefaultQuantBMmaCore; //////////////////////////////////////////////////////////////////////////////// @@ -173,10 +173,10 @@ template < /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultQuantBMmaCore { + layout::RowMajor, ElementB_, layout::ColumnMajor, + ElementQScale_, ElementQOffset_, LayoutQMeta_, QuantBlocking_, + ElementC_, LayoutC_, arch::OpClassTensorOp, Stages, + Operator_, false, CacheOpA, CacheOpB> { using Shape = Shape_; using WarpShape = WarpShape_; using InstructionShape = InstructionShape_; @@ -239,7 +239,7 @@ struct DefaultQuantBMmaCore::value, Shape::kK>; using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK/2>; + sizeof_bits::value, Shape::kK / 2>; // // Iterators to write to shared memory @@ -259,14 +259,14 @@ struct DefaultQuantBMmaCore, kThreads, + layout::PitchLinearShape, kThreads, layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to B operand using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, + MatrixShape, ElementB, SmemLayoutB, 1, IteratorThreadMapB>; using SmemLayoutQScale = LayoutQMeta; @@ -278,9 +278,9 @@ struct DefaultQuantBMmaCore 0, "QuantBlocking too big to fit in a thread block!"); static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, - "Only support single column or row quantize blocking!"); + "Only support single column or row quantize blocking!"); static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, - "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); /// Threadblock-level quantization meta data shape in pitch-linear layout using TBQPitchLinearShape = typename std::conditional< @@ -303,7 +303,7 @@ struct DefaultQuantBMmaCore; using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< - ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; static int const kElementsPerAccessQOffset = (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous @@ -316,7 +316,7 @@ struct DefaultQuantBMmaCore; using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< - ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; // // Warp-level matrix multiply operator @@ -330,7 +330,7 @@ struct DefaultQuantBMmaCore, - MatrixShape<0, 0>, WarpCount::kK>; + MatrixShape<0, 0>, WarpCount::kK>; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h index 6f27a692a3a2..1e02e1264e09 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -22,7 +22,6 @@ namespace cutlass { namespace transform { namespace threadblock { - //////////////////////////////////////////////////////////////////////////////// /// Optional 2-D matrix data loader, when element is std::monostate, the @@ -43,9 +42,8 @@ template < /// Number of threads in the threadblock, when provided, the iterator /// will utilize the higher numbered threads int kThreadBlockSize_ = -1> -class OptionalPredicatedTileAccessIterator{ +class OptionalPredicatedTileAccessIterator { public: - using Shape = Shape_; using Element = Element_; using Layout = Layout_; @@ -56,9 +54,9 @@ class OptionalPredicatedTileAccessIterator{ static constexpr int kThreadblockSize = kThreadBlockSize_; static_assert(!std::is_same::value, - "Disabled Iterator failed to match the specialized version below."); + "Disabled Iterator failed to match the specialized version below."); static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, - "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); using Base = PredicatedTileAccessIterator; @@ -72,7 +70,7 @@ class OptionalPredicatedTileAccessIterator{ static constexpr int kAccessesPerVector = Base::kAccessesPerVector; CUTLASS_HOST_DEVICE - static int flip_thread_id(int thread_id){ + static int flip_thread_id(int thread_id) { if constexpr (kThreadblockSize > 0) { return kThreadblockSize - 1 - thread_id; } @@ -80,17 +78,17 @@ class OptionalPredicatedTileAccessIterator{ } public: - Base base_; + Base base_; /// Default constructor - OptionalPredicatedTileAccessIterator(): base_() {}; + OptionalPredicatedTileAccessIterator() : base_(){}; /// Constructs a TileIterator from its precomputed state, threadblock offset, /// and thread ID CUTLASS_HOST_DEVICE OptionalPredicatedTileAccessIterator( /// Precomputed parameters object - Params const ¶ms, + Params const& params, /// Pointer to start of tensor Pointer pointer, /// Extent of tensor @@ -98,14 +96,14 @@ class OptionalPredicatedTileAccessIterator{ /// ID of each participating thread int thread_id, /// Initial offset of threadblock - TensorCoord const &threadblock_offset) + TensorCoord const& threadblock_offset) : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} /// Construct a PredicatedTileAccessIterator with zero threadblock offset CUTLASS_HOST_DEVICE OptionalPredicatedTileAccessIterator( /// Precomputed parameters object - Params const ¶ms, + Params const& params, /// Pointer to start of tensor Pointer pointer, /// Extent of tensor @@ -129,19 +127,19 @@ class OptionalPredicatedTileAccessIterator{ /// Advances an iterator along logical dimensions of matrix in units of whole tiles CUTLASS_DEVICE void add_tile_offset( - TensorCoord const &tile_offset) { + TensorCoord const& tile_offset) { base_.add_tile_offset(tile_offset); } /// Returns a pointer CUTLASS_HOST_DEVICE - AccessType *get() const { + AccessType* get() const { return base_.get(); } /// Increment and return an instance to self. CUTLASS_HOST_DEVICE - OptionalPredicatedTileAccessIterator &operator++() { + OptionalPredicatedTileAccessIterator& operator++() { ++base_; return *this; } @@ -168,13 +166,13 @@ class OptionalPredicatedTileAccessIterator{ /// Sets the predicate mask, overriding value stored in predicate iterator CUTLASS_HOST_DEVICE - void set_mask(Mask const &mask) { + void set_mask(Mask const& mask) { base_.set_mask(mask); } /// Gets the mask CUTLASS_HOST_DEVICE - void get_mask(Mask &mask) { + void get_mask(Mask& mask) { base_.get_mask(mask); } @@ -198,9 +196,8 @@ template < typename ThreadMap_, typename AccessType_, int kThreadBlockSize_> -class OptionalPredicatedTileAccessIterator{ +class OptionalPredicatedTileAccessIterator { public: - using Shape = Shape_; using Element = std::monostate; using Layout = Layout_; @@ -225,14 +222,14 @@ class OptionalPredicatedTileAccessIterator::value * ThreadMap_::kElementsPerAccess / 8> -class OptionalRegularTileAccessIterator{ + sizeof_bits::value* ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator { public: - using Shape = Shape_; using Element = Element_; using Layout = Layout_; @@ -59,9 +58,9 @@ class OptionalRegularTileAccessIterator{ static constexpr int kThreadblockSize = ThreadblockSize_; static_assert(!std::is_same::value, - "Disabled Iterator failed to match the specialized template"); + "Disabled Iterator failed to match the specialized template"); static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, - "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); using Base = RegularTileAccessIterator; @@ -71,7 +70,7 @@ class OptionalRegularTileAccessIterator{ using AccessType = typename Base::AccessType; CUTLASS_HOST_DEVICE - static int flip_thread_id(int thread_id){ + static int flip_thread_id(int thread_id) { if constexpr (kThreadblockSize > 0) { return kThreadblockSize - 1 - thread_id; } @@ -79,15 +78,14 @@ class OptionalRegularTileAccessIterator{ } private: - Base base_; public: /// Construct a TileIterator with zero threadblock offset CUTLASS_HOST_DEVICE OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor - int thread_id ///< ID of each participating thread - ) + int thread_id ///< ID of each participating thread + ) : base_(ref, flip_thread_id(thread_id)) {} /// Overrides the internal iteration index @@ -104,13 +102,13 @@ class OptionalRegularTileAccessIterator{ /// Returns a pointer CUTLASS_DEVICE - AccessType *get() const { + AccessType* get() const { return base_.get(); } /// Advances to the next tile in memory. CUTLASS_HOST_DEVICE - OptionalRegularTileAccessIterator &operator++() { + OptionalRegularTileAccessIterator& operator++() { ++base_; return *this; } @@ -134,7 +132,7 @@ class OptionalRegularTileAccessIterator{ /// Below two classes map col/row major to the pitch linear coordinates used /// in this base class. CUTLASS_DEVICE - void add_tile_offset(TensorCoord const &coord) { + void add_tile_offset(TensorCoord const& coord) { base_.add_tile_offset(coord); } }; @@ -151,9 +149,8 @@ template < int ThreadblockSize_, int Alignment> class OptionalRegularTileAccessIterator{ + AdvanceRank, ThreadMap_, ThreadblockSize_, Alignment> { public: - using Shape = Shape_; using Element = std::monostate; using Layout = Layout_; @@ -169,15 +166,14 @@ class OptionalRegularTileAccessIterator -struct QuantBLayoutDebug{ +struct QuantBLayoutDebug { static constexpr bool debug_smem = true; static constexpr bool debug_fragment = true; ElementWeight* smem_b_ptr_; @@ -77,15 +77,14 @@ struct QuantBLayoutDebug{ int lane_id_; int block_id_; - template - CUTLASS_DEVICE - static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + template + CUTLASS_DEVICE static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id) { static_assert(Size % 4 == 0, "Size must be multiple of 4"); - if constexpr (debug_fragment){ - if (block_id == 1 && warp_id == 0){ + if constexpr (debug_fragment) { + if (block_id == 1 && warp_id == 0) { const Element* ptr = reinterpret_cast(&frag); - for (int i = 0; i < Size/4; i++, ptr+=4){ - if constexpr(std::is_integral::value){ + for (int i = 0; i < Size / 4; i++, ptr += 4) { + if constexpr (std::is_integral::value) { printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", threadIdx.x, label, i, ptr[0], ptr[1], ptr[2], ptr[3]); @@ -99,21 +98,19 @@ struct QuantBLayoutDebug{ } } - template - CUTLASS_DEVICE - static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + template + CUTLASS_DEVICE static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id) { constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); - if constexpr (debug_fragment){ - if (block_id == 1 && warp_id == 0){ + if constexpr (debug_fragment) { + if (block_id == 1 && warp_id == 0) { const uint8_t* ptr = reinterpret_cast(&frag); - for (int i = 0; i < I8Size/2; i++, ptr+=2){ + for (int i = 0; i < I8Size / 2; i++, ptr += 2) { printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); } } } } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,8 +118,9 @@ struct QuantBLayoutDebug{ /// Dummy type when quant offset is not used, to avoid compilation error, /// and reduce runtime footprint /// -struct DummyType{ +struct DummyType { std::monostate dummy_; + public: DummyType() = default; @@ -137,7 +135,7 @@ struct DummyType{ } }; -} +} // namespace ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -239,8 +237,9 @@ class QuantBMmaBase { Shape::kN / QuantBlocking::kColumn>; using BufTypeQOffset = std::conditional_t, - DummyType>; + AlignedBuffer, + DummyType>; + public: // // Data members @@ -259,7 +258,6 @@ class QuantBMmaBase { BufTypeQOffset operand_QOffset; public: - // // Methods // @@ -306,7 +304,7 @@ class QuantBMmaBase { CUTLASS_HOST_DEVICE TensorRefQOffset operand_QOffset_ref() { - if constexpr (!kHasQOffset){ + if constexpr (!kHasQOffset) { return TensorRefQOffset(); } else { return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; @@ -315,7 +313,6 @@ class QuantBMmaBase { }; protected: - // // Data members // @@ -329,25 +326,21 @@ class QuantBMmaBase { /// Iterator to load a warp-scoped tile of quant scales from shared memory typename Operator::IteratorQMeta warp_tile_iterator_QScale_; -public: - + public: /// Construct from tensor references CUTLASS_DEVICE QuantBMmaBase( ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage &shared_storage, + SharedStorage& shared_storage, ///< ID within the threadblock int thread_idx, ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx - ): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), - warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), - shared_storage.operand_QOffset_ref(), lane_idx) - {} + int lane_idx) : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -397,9 +390,8 @@ template < int Stages, /// Used for partial specialization typename Enable = bool> -class QuantBMmaMultistage : - public QuantBMmaBase { -public: +class QuantBMmaMultistage : public QuantBMmaBase { + public: ///< Base class using Base = QuantBMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> @@ -452,7 +444,6 @@ class QuantBMmaMultistage : /// Internal structure exposed for introspection. struct Detail { - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -497,11 +488,8 @@ class QuantBMmaMultistage : }; private: - - // Structure encapsulating pipeline state live from one iteration to the next struct PipeState { - using WarpLoadedFragmentA = typename Operator::FragmentA; using WarpLoadedFragmentB = typename Operator::FragmentB; using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; @@ -521,14 +509,12 @@ class QuantBMmaMultistage : WarpLoadedFragmentQScale warp_loaded_frag_QScale_; using WarpLoadedFragmentQOffset = typename std::conditional::type; + typename Operator::FragmentQOffset, + std::monostate>::type; WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; }; - private: - // // Data members // @@ -559,43 +545,39 @@ class QuantBMmaMultistage : /// Shared memory pointers for debug dumping static constexpr bool debug_layout = false; using LayoutDebugType = typename std::conditional, - std::monostate>::type; + QuantBLayoutDebug, + std::monostate>::type; LayoutDebugType layout_debug_; -public: - + public: /// Construct from tensor references CUTLASS_DEVICE QuantBMmaMultistage( ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, + typename Base::SharedStorage& shared_storage, ///< ID within the threadblock int thread_idx, ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), - smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), - should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), - should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) - { + int lane_idx) : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: // _m: the warp's position within the threadblock along the M dimension // _n: the warp's position within the threadblock along the N dimension // _k: the warp's position within the threadblock along the K dimension - if constexpr(debug_layout){ + if constexpr (debug_layout) { layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); } else { layout_debug_.smem_qoffset_ptr_ = nullptr; @@ -622,8 +604,7 @@ class QuantBMmaMultistage : /// Advance shared memory read-iterators to the next stage CUTLASS_DEVICE - void advance_smem_read_stage() - { + void advance_smem_read_stage() { ++smem_read_stage_idx_; if (smem_read_stage_idx_ == Base::kStages) { @@ -639,11 +620,10 @@ class QuantBMmaMultistage : /// Advance global memory read-iterators and shared memory write-iterators to the stage CUTLASS_DEVICE void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - IteratorQScale &iterator_QScale, - IteratorQOffset &iterator_QOffset) - { + IteratorA& iterator_A, + IteratorB& iterator_B, + IteratorQScale& iterator_QScale, + IteratorQOffset& iterator_QOffset) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); iterator_B.add_tile_offset({1, 0}); @@ -675,7 +655,7 @@ class QuantBMmaMultistage : } CUTLASS_DEVICE - void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + void copy_qscale_tiles(IteratorQScale& iterator_QScale) { // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so @@ -687,41 +667,41 @@ class QuantBMmaMultistage : "Quant scale should 1 access per vector!"); // Async Copy for quantization scale - typename IteratorQScale::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQScale::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QScale_.get()); constexpr int kSrcBytes = sizeof_bits::value * - IteratorQScale::ThreadMap::kElementsPerAccess / 8; + IteratorQScale::ThreadMap::kElementsPerAccess / 8; cutlass::arch::cp_async( dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); } CUTLASS_DEVICE - void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + void copy_qoffset_tiles(IteratorQOffset& iterator_QOffset) { static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); - if constexpr(kHasQOffset) { + if constexpr (kHasQOffset) { // Async Copy for quantization offset - typename IteratorQOffset::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQOffset::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QOffset_.get()); constexpr int kSrcBytes = sizeof_bits::value * IteratorQOffset::ThreadMap::kElementsPerAccess / 8; cutlass::arch::cp_async( - dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); } } CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start = 0) { auto group_start_A = group_start * Detail::kAccessesPerGroupA; iterator_A.set_iteration_index(group_start_A * @@ -732,8 +712,8 @@ class QuantBMmaMultistage : CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_A_.get()); int const kSrcBytes = sizeof_bits::value * @@ -763,8 +743,8 @@ class QuantBMmaMultistage : CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_B_.get()); int const kSrcBytes = sizeof_bits::value * @@ -789,16 +769,15 @@ class QuantBMmaMultistage : /// the global fragments needed by the first kStages-1 threadblock mainloop iterations CUTLASS_DEVICE void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory - IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset& iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int& gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages CUTLASS_PRAGMA_UNROLL for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); @@ -810,8 +789,8 @@ class QuantBMmaMultistage : // Async Copy for operand A CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_A_.get()); CUTLASS_PRAGMA_UNROLL @@ -838,8 +817,8 @@ class QuantBMmaMultistage : // Async Copy for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_B_.get()); CUTLASS_PRAGMA_UNROLL @@ -862,8 +841,8 @@ class QuantBMmaMultistage : static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); - typename IteratorQScale::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQScale::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QScale_.get()); constexpr int kSrcBytes = @@ -881,13 +860,13 @@ class QuantBMmaMultistage : // Async Copy for quantization offset static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); - typename IteratorQOffset::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorQOffset::AccessType* dst_ptr = + reinterpret_cast( this->smem_iterator_QOffset_.get()); constexpr int kSrcBytes = sizeof_bits::value * - IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; cutlass::arch::cp_async( dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); @@ -901,22 +880,20 @@ class QuantBMmaMultistage : } } - /// Wait until we have at least one completed global fetch stage CUTLASS_DEVICE - void gmem_wait() - { + void gmem_wait() { // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) cutlass::arch::cp_async_wait(); __syncthreads(); - if constexpr(debug_layout) { + if constexpr (debug_layout) { if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { - if (threadIdx.x == 0){ + if (threadIdx.x == 0) { printf("stage: %d\n", smem_write_stage_idx_); } cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); } } @@ -926,13 +903,13 @@ class QuantBMmaMultistage : /// Perform a threadblock mainloop iteration of matrix multiply-accumulate CUTLASS_DEVICE void mac_loop_iter( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory - IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + PipeState& pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC& accum, ///< [in|out] destination accumulator tile + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset& iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int& gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL @@ -960,35 +937,34 @@ class QuantBMmaMultistage : iterator_B, (warp_mma_k + 1) % Base::kWarpGemmIterations); - if constexpr(debug_layout) { - if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + if constexpr (debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0) { printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); } LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } warp_mma_.transform( - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_loaded_frag_QScale_, - pipe_state.warp_loaded_frag_QOffset_); + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout) { + if constexpr (debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } // Execute the current warp-tile of MMA operations if (Detail::kStagedAccumulation) { warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); if (warp_mma_k == 0) { plus plus_accum; @@ -997,11 +973,10 @@ class QuantBMmaMultistage : } } else { warp_mma_( - accum, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum); } if (warp_mma_k == 0) { @@ -1025,51 +1000,50 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } // Wait until we have at least one completed global fetch stage gmem_wait(); } - } } /// Specialized mainloop iteration of matrix multiply-accumulate, for small M CUTLASS_DEVICE void mac_loop_iter_small_m( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory - IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + PipeState& pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC& accum, ///< [in|out] destination accumulator tile + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset& iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int& gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { // In the case of small M, memory latency dominates. We try to move uses far // from their definitions to hide latency. - if constexpr(debug_layout) { - if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + if constexpr (debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0) { printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); } LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } warp_mma_.transform( - pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_loaded_frag_QScale_, - pipe_state.warp_loaded_frag_QOffset_); + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout) { + if constexpr (debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } @@ -1096,11 +1070,10 @@ class QuantBMmaMultistage : // Execute the current warp-tile of MMA operations if (Detail::kStagedAccumulation) { warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); if (warp_mma_k == 0) { plus plus_accum; @@ -1109,11 +1082,10 @@ class QuantBMmaMultistage : } } else { warp_mma_( - accum, - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum); } // The second-to-last warp-tile also moves to the next global fetch stage @@ -1130,7 +1102,7 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } @@ -1140,21 +1112,19 @@ class QuantBMmaMultistage : // Wait until we have at least one completed global fetch stage gmem_wait(); } - } } - /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. CUTLASS_DEVICE void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory - IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC& accum, ///< [in|out] accumulator tile + IteratorA& iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB& iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale& iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset& iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory { PipeState pipe_state; @@ -1162,7 +1132,7 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset) { + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } @@ -1183,26 +1153,26 @@ class QuantBMmaMultistage : copy_tiles_and_advance(iterator_A, iterator_B, 0); - if constexpr(Shape::kM > 32) { + if constexpr (Shape::kM > 32) { // the case of bigger m - if constexpr(debug_layout) { - if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + if constexpr (debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0) { printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); } LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); - if constexpr(kHasQOffset){ + if constexpr (kHasQOffset) { LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } warp_mma_.transform( - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_loaded_frag_QScale_, - pipe_state.warp_loaded_frag_QOffset_); + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout) { + if constexpr (debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } else { @@ -1218,24 +1188,24 @@ class QuantBMmaMultistage : // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { - if constexpr(Shape::kM > 32) { + if constexpr (Shape::kM > 32) { mac_loop_iter( - pipe_state, - accum, - iterator_A, - iterator_B, - iterator_QScale, - iterator_QOffset, - gemm_k_iterations); + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); } else { mac_loop_iter_small_m( - pipe_state, - accum, - iterator_A, - iterator_B, - iterator_QScale, - iterator_QOffset, - gemm_k_iterations); + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); } } @@ -1248,17 +1218,15 @@ class QuantBMmaMultistage : cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); - } - /// Perform a threadblock-scoped matrix multiply-accumulate CUTLASS_DEVICE void operator()( ///< problem size of GEMM int gemm_k_iterations, ///< destination accumulator tile - FragmentC &accum, + FragmentC& accum, ///< iterator over A operand in global memory IteratorA iterator_A, ///< iterator over B operand in global memory @@ -1268,8 +1236,7 @@ class QuantBMmaMultistage : ///< Iterator over quant offsets in global memory IteratorQOffset iterator_QOffset, ///< initial value of accumulator - FragmentC const &src_accum) { - + FragmentC const& src_accum) { // Prologue (start fetching iterations of global fragments into shared memory) prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h index 2c49888c9450..858e616e9b5b 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -101,9 +101,9 @@ struct DefaultQuantBMmaTensorOp { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 0b9ac0cb0e08..ce00c74103d7 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -33,21 +33,21 @@ //////////////////////////////////////////////////////////////////////////////// -namespace{ +namespace { -struct b32_pair{ +struct b32_pair { uint32_t a; uint32_t b; }; -struct fp16_quad{ +struct fp16_quad { cutlass::half_t a; cutlass::half_t b; cutlass::half_t c; cutlass::half_t d; }; -struct b16_quad{ +struct b16_quad { int16_t a; int16_t b; int16_t c; @@ -66,17 +66,15 @@ static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); /// Convert packed 4b weights into fp16(weight + 16) /// Current bit hacking only supports fp16, need to add bf16 later. /// -template -CUTLASS_DEVICE -void weights2Half(cutlass::Array const &weights, - cutlass::Array& dest) -{ +template +CUTLASS_DEVICE void weights2Half(cutlass::Array const& weights, + cutlass::Array& dest) { static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); uint32_t* dest_pair = reinterpret_cast(dest.data()); const uint32_t* w_oct = reinterpret_cast(weights.data()); CUTLASS_PRAGMA_UNROLL - for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ + for (int oct_idx = 0; oct_idx < Size / 8; oct_idx++, w_oct++, dest_pair += 4) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) // static_cast(16 + weight) @@ -88,7 +86,7 @@ void weights2Half(cutlass::Array const &weights, " shl.b32 %1, %4, 2;\n" " shr.u32 %2, %4, 2;\n" " shr.u32 %3, %4, 6;\n" - " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" @@ -100,10 +98,9 @@ void weights2Half(cutlass::Array const &weights, assert(0); #endif } - } -} // namespace +} // namespace //////////////////////////////////////////////////////////////////////////////// @@ -117,18 +114,17 @@ namespace warp { // Since operand B is quantized on a per block basis, it's one meta data per block. template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads> -class QuantBMetaMmaTile{ -public: - + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile { + public: using WarpShapeB = WarpShapeB_; using BlockingShape = BlockingShape_; using ArchMmaOperator = ArchMmaOperator_; @@ -169,8 +165,9 @@ class QuantBMetaMmaTile{ /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension /// exceeds the tile depth, so two tiles share the same meta data static int const kTilesPerMma = ((kBTilesPerMma == 2) && - (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) - ? 2 : 1; + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 + : 1; /// stride to reach the meta data for the next CoreTile on the K dimension static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; @@ -190,24 +187,21 @@ class QuantBMetaMmaTile{ CUTLASS_DEVICE static MatrixCoord lane_position(int lane_id) { - if constexpr(kNumBsPerCoreTileFragement == 2 - && kBTilesPerMma == 2 - && BlockingShape::kRow == 1){ + if constexpr (kNumBsPerCoreTileFragement == 2 && kBTilesPerMma == 2 && BlockingShape::kRow == 1) { // Optimize for a special case of: // 16b gemm (kNumBsPerCoreTileFragement == 2) // 2 B operand tiles per mma (kBTilesPerMma == 2) // (1,n) quantization blocking // The scale and offset tensors are prepacked to reduce the number of load instructions. return make_Coord((lane_id % CoreTile::kContiguous) * 4, - lane_id / CoreTile::kContiguous); + lane_id / CoreTile::kContiguous); } else { return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, - lane_id / CoreTile::kContiguous); + lane_id / CoreTile::kContiguous); } } }; - //////////////////////////////////////////////////////////////////////////////// /// This tile iterator is to load quantization meta data for operand B from @@ -222,25 +216,25 @@ class QuantBMetaMmaTile{ /// out the operand B layout in the tensor core. /// template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Data type of the quant scales - typename ElementScale_, - /// Layout of the quant scales - typename LayoutScale_, - /// Data type of quant offsets - typename ElementOffset_, - /// Layout of quant offsets - typename LayoutOffset_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads, - /// Number of partitions along K dimension - int PartitionsK_ = 1> + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> class QuantBMetaMmaTensorOpTileIterator; //////////////////////////////////////////////////////////////////////////////// @@ -248,25 +242,24 @@ class QuantBMetaMmaTensorOpTileIterator; /// Specialization for column major layout template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Data type of the meta data elements - typename ElementScale_, - /// Data type of quant offsets - typename ElementOffset_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads> + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> class QuantBMetaMmaTensorOpTileIterator{ -public: - + ElementScale_, cutlass::layout::ColumnMajor, + ElementOffset_, cutlass::layout::ColumnMajor, + ArchMmaOperator_, Threads, 1> { + public: using WarpShapeB = WarpShapeB_; using BlockingShape = BlockingShape_; using ElementScale = ElementScale_; @@ -277,7 +270,7 @@ class QuantBMetaMmaTensorOpTileIterator::value); static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, - "Only support row blocking for column major layout"); + "Only support row blocking for column major layout"); using MetaTile = QuantBMetaMmaTile; @@ -316,44 +309,39 @@ class QuantBMetaMmaTensorOpTileIterator; using FragmentOffset = typename std::conditional, - std::monostate>::type; + Array, + std::monostate>::type; using AccessTypeScale = Array; using AccessTypeOffset = Array; -private: - - ElementScale *pointer_; + private: + ElementScale* pointer_; Layout layout_; - ElementOffset *pointer_offset_; + ElementOffset* pointer_offset_; Layout layout_offset_; TensorCoord lane_position_; -public: - + public: CUTLASS_DEVICE - QuantBMetaMmaTensorOpTileIterator() { } + QuantBMetaMmaTensorOpTileIterator() {} CUTLASS_DEVICE QuantBMetaMmaTensorOpTileIterator( - TensorRefScale const &ref, - TensorRefOffset const &ref_offset, - int lane_idx - ): - pointer_(ref.data()), - layout_(ref.layout()), - pointer_offset_(ref_offset.data()), - layout_offset_(ref_offset.layout()), - lane_position_(MetaTile::lane_position(lane_idx)){} + TensorRefScale const& ref, + TensorRefOffset const& ref_offset, + int lane_idx) : pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) {} /// Loads a fragment CUTLASS_HOST_DEVICE - void load(FragmentScale &frag, FragmentOffset &frag_offset) { - if constexpr(kNumBsPerCoreTileFragement == 2 - && kBTilesPerMma == 2){ + void load(FragmentScale& frag, FragmentOffset& frag_offset) { + if constexpr (kNumBsPerCoreTileFragement == 2 && kBTilesPerMma == 2) { // Optimize for a special case of: // 16b gemm (kNumBsPerCoreTileFragement == 2) // 2 B operand tiles per mma (kBTilesPerMma == 2) @@ -362,19 +350,19 @@ class QuantBMetaMmaTensorOpTileIterator *dst_ptr = reinterpret_cast*>(frag.data()); + Array* dst_ptr = reinterpret_cast*>(frag.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ - Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { + Array* src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); *dst_ptr = *src_ptr; dst_ptr++; } - if constexpr(kHasOffset){ - Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + if constexpr (kHasOffset) { + Array* dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ - Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { + Array* src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); *dst_ptr_offset = *src_ptr_offset; dst_ptr_offset++; } @@ -388,21 +376,21 @@ class QuantBMetaMmaTensorOpTileIterator(frag.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { CUTLASS_PRAGMA_UNROLL - for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride) { AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); *dst_ptr = *src_ptr; dst_ptr++; } } - if constexpr(kHasOffset){ + if constexpr (kHasOffset) { AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride) { CUTLASS_PRAGMA_UNROLL - for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride) { AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); *dst_ptr = *src_ptr; dst_ptr++; @@ -413,18 +401,17 @@ class QuantBMetaMmaTensorOpTileIterator - CUTLASS_HOST_DEVICE - static Array debug_expand(Array const &frag){ + CUTLASS_HOST_DEVICE static Array debug_expand(Array const& frag) { Array ret; int out_idx = 0; CUTLASS_PRAGMA_UNROLL - for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + for (int n_out = 0; n_out < kMmaIterationsB; n_out++) { int n_idx = n_out / kNRepeats; CUTLASS_PRAGMA_UNROLL - for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++) { int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); CUTLASS_PRAGMA_UNROLL - for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++) { int elem_idx = elem_out_idx / BlockingShape::kRow; int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; ret[out_idx] = frag[idx]; @@ -436,17 +423,17 @@ class QuantBMetaMmaTensorOpTileIterator const &weights, - Array& dest){ + static void dequant(FragmentScale const& scales, + FragmentOffset const& fragment_offsets, + Array const& weights, + Array& dest) { static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); // First convert 4b weight into fp16(weight + 16) weights2Half(weights, dest); - if constexpr(kBTilesPerMma == 2){ + if constexpr (kBTilesPerMma == 2) { // Optimize for a special case of: // 2 B operand tiles per mma (kBTilesPerMma == 2) // (1,n) quantization blocking (BlockingShape::kRow == 1) @@ -454,28 +441,30 @@ class QuantBMetaMmaTensorOpTileIterator(dest.data()); const b64* scales_ptr = reinterpret_cast(scales.data()); [[maybe_unused]] const ElementOffset* fragment_offsets_ptr = nullptr; - if constexpr(kHasOffset) { fragment_offsets_ptr = fragment_offsets.data(); } + if constexpr (kHasOffset) { + fragment_offsets_ptr = fragment_offsets.data(); + } CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++) { // dequantize: d = scale * (weight - offset) // to use FMA, d = scale * weight + (scale * (-offset)) [[maybe_unused]] b64 offsets{0}; - if constexpr(kHasOffset) { + if constexpr (kHasOffset) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) const uint32_t* p = reinterpret_cast(fragment_offsets_ptr); asm volatile( "{\n\t" - " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands // static_cast(-16 - offset) // input [d, b, c, a], - " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 - " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 - " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" - " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) " mul.rn.f16x2 %1, %3, rb1;\n" "}\n" : "=r"(offsets.pair.a), "=r"(offsets.pair.b) @@ -492,25 +481,25 @@ class QuantBMetaMmaTensorOpTileIteratorpair.a), "r"(scales_ptr->pair.b)); #else - offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); - offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); - offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); - offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16 - 8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16 - 8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16 - 8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16 - 8); #endif } CUTLASS_PRAGMA_UNROLL - for (int n_r = 0; n_r < kNRepeats; n_r++){ + for (int n_r = 0; n_r < kNRepeats; n_r++) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) asm volatile( "{\n\t" - " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) " fma.rn.f16x2 %1, %3, %1, %5;\n" "}\n" : "+r"(dest_pair[0]), "+r"(dest_pair[1]) @@ -529,75 +518,70 @@ class QuantBMetaMmaTensorOpTileIterator(-16 - static_cast(fragment_offsets[idx])); } else { - offset = s * static_cast(-16-8); + offset = s * static_cast(-16 - 8); } dest[out_idx] = s * dest[out_idx] + offset; out_idx++; } } } - } - } /// Advances the pointer CUTLASS_HOST_DEVICE - QuantBMetaMmaTensorOpTileIterator &operator++() { + QuantBMetaMmaTensorOpTileIterator& operator++() { // This is for operand B, so advance on the K dimension lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); return *this; } CUTLASS_DEVICE - QuantBMetaMmaTensorOpTileIterator &add_tile_offset( - TensorCoord const &tile_offset) { + QuantBMetaMmaTensorOpTileIterator& add_tile_offset( + TensorCoord const& tile_offset) { int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; lane_position_ += TensorCoord(rows, columns); return *this; } - }; - //////////////////////////////////////////////////////////////////////////////// /// Specialization for row major layout template < - /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) - typename WarpShapeB_, - /// Block dimensions of the blockwise quantization. So the actual meta data - /// warp shape is WarpShapeB_ / BlockingShape_ - typename BlockingShape_, - /// Data type of the meta data elements - typename ElementScale_, - /// Data type of quant offsets - typename ElementOffset_, - /// Underlying matrix multiply operator (concept: arch::Mma) - typename ArchMmaOperator_, - /// Number of threads participating in one matrix operation - int Threads> + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> class QuantBMetaMmaTensorOpTileIterator{ -public: - + ElementScale_, cutlass::layout::RowMajor, + ElementOffset_, cutlass::layout::RowMajor, + ArchMmaOperator_, Threads, 1> { + public: using WarpShapeB = WarpShapeB_; using BlockingShape = BlockingShape_; using ElementScale = ElementScale_; @@ -608,7 +592,7 @@ class QuantBMetaMmaTensorOpTileIterator::value); static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, - "Only support column blocking for row major layout"); + "Only support column blocking for row major layout"); using MetaTile = QuantBMetaMmaTile; @@ -647,40 +631,35 @@ class QuantBMetaMmaTensorOpTileIterator; using FragmentOffset = typename std::conditional, - std::monostate>::type; + Array, + std::monostate>::type; -private: - - ElementScale *pointer_; + private: + ElementScale* pointer_; Layout layout_; - ElementOffset *pointer_offset_; + ElementOffset* pointer_offset_; Layout layout_offset_; TensorCoord lane_position_; -public: - + public: CUTLASS_DEVICE - QuantBMetaMmaTensorOpTileIterator() { } + QuantBMetaMmaTensorOpTileIterator() {} CUTLASS_DEVICE QuantBMetaMmaTensorOpTileIterator( - TensorRefScale const &ref, - TensorRefOffset const &ref_offset, - int lane_idx - ): - pointer_(ref.data()), - layout_(ref.layout()), - pointer_offset_(ref_offset.data()), - layout_offset_(ref_offset.layout()), - lane_position_(MetaTile::lane_position(lane_idx)) - {} + TensorRefScale const& ref, + TensorRefOffset const& ref_offset, + int lane_idx) : pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) {} /// Loads a fragment CUTLASS_HOST_DEVICE - void load(FragmentScale &frag, FragmentOffset &frag_offset) { + void load(FragmentScale& frag, FragmentOffset& frag_offset) { const int row = lane_position_.row() / BlockingShape::kRow; const int column = lane_position_.column() / BlockingShape::kColumn; static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); @@ -688,34 +667,33 @@ class QuantBMetaMmaTensorOpTileIterator - CUTLASS_HOST_DEVICE - static Array debug_expand(Array const &frag){ + CUTLASS_HOST_DEVICE static Array debug_expand(Array const& frag) { Array ret; int out_idx = 0; CUTLASS_PRAGMA_UNROLL - for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + for (int n_out = 0; n_out < kMmaIterationsB; n_out++) { int n_idx = n_out / kNRepeats; CUTLASS_PRAGMA_UNROLL - for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++) { int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); CUTLASS_PRAGMA_UNROLL - for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++) { int elem_idx = elem_out_idx / BlockingShape::kRow; int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; int idx = col * kMmaIterations + n_idx; @@ -728,10 +706,10 @@ class QuantBMetaMmaTensorOpTileIterator const &weights, - Array& dest){ + static void dequant(FragmentScale const& scales, + FragmentOffset const& offsets, + Array const& weights, + Array& dest) { static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); @@ -742,30 +720,30 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); uint32_t* addon_ptr = reinterpret_cast(addon); - if constexpr(kHasOffset){ + if constexpr (kHasOffset) { const uint32_t* p = reinterpret_cast(offsets.data()); CUTLASS_PRAGMA_UNROLL - for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) asm volatile( - "{\n\t" - " .reg .b32 rb0, rb1, rb2;\n" + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" - // offset from [d, c, b, a] --> [d, b, c, a] - " prmt.b32 rb2, %4, rb0, 0x3120;\n" + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" - // static_cast(-16 - offset) - // input [d, b, c, a], - " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 - " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 - " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 - " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" - " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) - " mul.rn.f16x2 %1, %3, rb1;\n" - "}\n" - : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) - : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), - "r"(p[0])); + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); #else assert(0); #endif @@ -775,17 +753,17 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) asm volatile( - "{\n\t" - " .reg .b32 rb0;\n" - " mov.u32 rb0, 0xce00ce00;\n" - " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) - " mul.rn.f16x2 %1, %3, rb0;\n" - "}\n" - : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) - : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); #else assert(0); #endif @@ -794,7 +772,7 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) const uint32_t* scales_ptr = reinterpret_cast(scales.data()); uint32_t* addon_ptr = reinterpret_cast(addon); @@ -802,21 +780,20 @@ class QuantBMetaMmaTensorOpTileIterator(offsets.data()); asm volatile( - "{\n\t" - " .reg .b32 rb0, rb1, rb2;\n" - - // offset from [?, ?, b, a] --> [?, b, ?, a] - " prmt.b32 rb2, %2, rb0, 0x3120;\n" - - // static_cast(-16 - offset) - // input [d, b, c, a], - " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 - " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 - " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) - "}\n" - : "=r"(addon_ptr[0]) - : "r"(scales_ptr[0]) - "r"(p[0])); + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) "r"(p[0])); #else assert(0); #endif @@ -825,32 +802,32 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); uint32_t* addon_ptr = reinterpret_cast(addon); asm volatile( - "{\n\t" - " .reg .b32 rb0;\n" - " mov.u32 rb0, 0xce00ce00;\n" - " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) - "}\n" - : "=r"(addon_ptr[0]) - : "r"(scales_ptr[0])); + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); #else assert(0); #endif } } else { // kMmaIterationsB == 1 - if constexpr(kHasOffset){ + if constexpr (kHasOffset) { uint8_t zp = offsets[0]; addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); } else { - addon[0] = scales[0] * static_cast(-16-8); + addon[0] = scales[0] * static_cast(-16 - 8); } } int out_idx = 0; CUTLASS_PRAGMA_UNROLL - for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + for (int n_out = 0; n_out < kMmaIterationsB; n_out++) { CUTLASS_PRAGMA_UNROLL - for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++) { dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; out_idx += 2; @@ -860,24 +837,22 @@ class QuantBMetaMmaTensorOpTileIterator - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Data type of quant scales - typename ElementQScale_, - /// Layout of quant scales (concept: MatrixLayout) - typename SmemLayoutQScale_, - /// Data type of quant offsets - typename ElementQOffset_, - /// Layout of quant offsets (concept: MatrixLayout) - typename SmemLayoutQOffset_, - /// Blocking dimensions of quantization - typename QuantBlocking_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool -> + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> class QuantBMmaTensorOp { -public: + public: /// Shape of warp-level matrix operation (concept: GemmShape) using Shape = Shape_; @@ -157,13 +156,12 @@ class QuantBMmaTensorOp { /// Number of partitions along K dimension static int const kPartitionsK = PartitionsK_; -public: - + public: /// Iterates over the A operand in memory using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kA, ElementA, LayoutA, - MatrixShape, - Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; /// Storage for A tile using FragmentA = typename IteratorA::Fragment; @@ -174,8 +172,8 @@ class QuantBMmaTensorOp { /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, Operand::kB, ElementB, LayoutB, - MatrixShape, + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; // warp B MatrixShape<64, 64>, // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, @@ -184,7 +182,7 @@ class QuantBMmaTensorOp { // FragmentB::kElements 32 /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; // cutlass::Array + using FragmentB = typename IteratorB::Fragment; // cutlass::Array /// Storage for transformed B tile /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded @@ -196,8 +194,8 @@ class QuantBMmaTensorOp { /// Iterates over the C operand in memory using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; /// Storage for C tile using FragmentC = typename IteratorC::Fragment; @@ -218,26 +216,23 @@ class QuantBMmaTensorOp { // TODO This is an expanding iterator, it needs to replicate the quantization parameters // to all threads in the warp. using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< - MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, - ElementQOffset, SmemLayoutQOffset, - ArchMmaOperator, kThreadCount, kPartitionsK>; + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; using FragmentQScale = typename IteratorQMeta::FragmentScale; using FragmentQOffset = typename IteratorQMeta::FragmentOffset; /// Number of mma operations performed using MmaIterations = MatrixShape< - (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN - >; - -public: + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + public: /// Underlying matrix multiply operator (concept: arch::Mma) ArchMmaOperator mma; -public: - + public: // // Methods // @@ -249,113 +244,106 @@ class QuantBMmaTensorOp { /// Performs a warp-level matrix multiply-accumulate operation CUTLASS_DEVICE void operator()( - FragmentC &D, - TransformedFragmentA const &A, - TransformedFragmentB const &B, - FragmentC const &C - ) const { - + FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C) const { using MmaOperandA = typename ArchMmaOperator::FragmentA; using MmaOperandB = typename ArchMmaOperator::FragmentB; using MmaOperandC = typename ArchMmaOperator::FragmentC; D = C; - MmaOperandA const *ptr_A = reinterpret_cast(&A); - MmaOperandB const *ptr_B = reinterpret_cast(&B); - MmaOperandC *ptr_D = reinterpret_cast(&D); - - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - // The visitation order is like - // _ - // | | | | - // | | | | - // |_| |_| - // - // Down Up Down Up - + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n], ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } else { - mma( + } else { + mma( ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n], ptr_D[m_serpentine + n * MmaIterations::kRow]); - } } } - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - // The visitation order is like - // _________ - // _________| - // |_________ - // __________| - // - // Right Left Right Left - + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine], ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } } - #else - assert(0); - #endif + } +#else + assert(0); +#endif } /// Transform the mma operands to the required types CUTLASS_DEVICE - void transform(TransformedFragmentB &dst_B, - FragmentB const &B, - FragmentQScale const &scales, - FragmentQOffset const &offsets) const { - - Array const *ptr_B = - reinterpret_cast const *>(&B); + void transform(TransformedFragmentB& dst_B, + FragmentB const& B, + FragmentQScale const& scales, + FragmentQOffset const& offsets) const { + Array const* ptr_B = + reinterpret_cast const*>(&B); IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// -//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" +// #include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 4b852be951c9..81789386a320 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -16,10 +16,11 @@ Module Name: --*/ #include "sqnbitgemm.h" -#include "sqnbitgemm_q8_block.h" #include +#include "sqnbitgemm_q8_block.h" + namespace { @@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable( Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && + return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr; } default: { @@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32( if (bias) { AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); } + if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, RowsHandled, CountN, ldc ); } c_blk += ldc * RowsHandled; a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; } } @@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8( const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } - - // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. - // TODO Replace it with an optimized implementation. size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); @@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc ); } - c_blk += ldc; - a_row += lda; + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index effb59b250ca..8321dcc217e9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -184,7 +184,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. * * @param BlkLen Number of values in a block. * @param QuantA Supplies the quantized A matrix. @@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param QuantBScale Supplies the quantized B matrix block scale values. * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. + * @param CountM Number of rows of A and C to process, an upper bound. + * @param CountN Number of columns of B and C to process. * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param BlockCountK Number of blocks in one row of A and one column of B. + * @param ldc Number of elements between adjacent rows of C. * @param Bias Bias vector of length N. + * + * @return The number of rows of A and C that were processed, at most CountM. */ - typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( + typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, + size_t CountM, size_t CountN, size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, + size_t ldc, const float* Bias ); - SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index be573381c39c..0922f5ef646b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } } +size_t +SQ4BitGemmKernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + MLAS_UNREFERENCED_PARAMETER(ldc); + + if (CountM == 0) { + return 0; + } + + SQ4BitGemmM1Kernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + + return 1; +} + template MLAS_FORCEINLINE void ComputeDotProducts_BlkLen16_CompFp32_avx2( @@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 0099b61d8196..b86890676070 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 27310d825334..6477a2019b21 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +size_t +SQ4BitGemmKernel_CompInt8_avx512vnni( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + MLAS_UNREFERENCED_PARAMETER(ldc); + + if (CountM == 0) { + return 0; + } + + SQ4BitGemmM1Kernel_CompInt8_avx512vnni( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + + return 1; +} + void MLASCALL MlasQ80BlkQuantRow_avx512( size_t BlkLen, @@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni; d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index cfc0564cd041..706e08fc467b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( const size_t BlockStrideQuantB ); -void -SQ4BitGemmM1Kernel_CompInt8_avx2( +size_t +SQ4BitGemmKernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, float* C, + size_t CountM, size_t CountN, size_t CountK, - size_t BlockStrideQuantB, + size_t BlockCountK, + size_t ldc, const float* Bias ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 6d1864794f94..3f32cc6c5312 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + sqnbitgemm_kernel_neon.cpp Abstract: @@ -17,20 +17,22 @@ Module Name: #include -#include #include -#include #include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" -// -// Quantized B data packing function implementation. -// +namespace sqnbitgemm_neon +{ namespace { +// +// Quantized B data packing function implementation. +// + size_t SQ4BitGemmPackQuantBDataSize( size_t N, @@ -134,7 +136,7 @@ SQ4BitGemmPerGemmWorkspaceSize( { MLAS_UNREFERENCED_PARAMETER(N); - switch(ComputeType) { + switch (ComputeType) { case CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -167,1316 +169,7 @@ SQ4BitGemmPerGemmWorkspaceAlignment( } // namespace -// -// General helpers. -// - -namespace -{ - -template -MLAS_FORCEINLINE void -UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) -{ - (f(Indices), ...); -} - -template -MLAS_FORCEINLINE void -UnrolledLoop(IterationFn&& f) -{ - UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); -} - -MLAS_FORCEINLINE void -Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) -{ - // aN: aN_0 aN_1 aN_2 aN_3 - - float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 - float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 - float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 - float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 - - // a0_0 a1_0 a2_0 a3_0 - a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_1 a1_1 a2_1 a3_1 - a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_2 a1_2 a3_2 a3_2 - a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); - // a0_3 a1_3 a2_3 a3_3 - a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); -} - -MLAS_FORCEINLINE float32x4_t -FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) -{ - Transpose4x4(a0, a1, a2, a3); - return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); -} - -template -MLAS_FORCEINLINE void -LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) -{ - static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); - - assert(count <= Capacity); - - size_t vi = 0; // vector index - - // handle 4 values at a time - while (count > 3) { - dst[vi] = vld1q_f32(src); - - vi += 1; - src += 4; - count -= 4; - } - - // handle remaining values - if (count > 0) { - dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); - - if (count > 1) { - dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); - - if (count > 2) { - dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); - } - } - } -} - -} // namespace - -// -// CompFp32 kernel implementation. -// - -namespace -{ - -namespace fp32_conversion -{ - -// Manual conversion to float takes place in two steps: -// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. -// This target float range is convenient because the 4-bit source values can be placed directly into the -// target float bits. -// 2. Subtract the conversion offset of 16 from the float result. - -// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. -constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; -// sign|exponent|partial mantissa -// +|131: 2^4|~~~~ <- 4 bits go here - -const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); - -constexpr float offset = 16.0f; - -} // namespace fp32_conversion - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompFp32( - size_t BlkLen, - const float* ARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; - - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - - assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); - - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - float32x4_t acc[NCols]{}; - - const std::byte* QuantBData = QuantBDataColPtr; - const float* QuantBScale = QuantBScaleColPtr; - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint is true - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - float scale[NCols]; - UnrolledLoop( - [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } - ); - - [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. - // only used if HasZeroPoint is true - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = fp32_conversion::offset + std::to_integer(zp); - }); - } - - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { - // load A row vector elements - - // load `SubBlkLen` elements from A, padded with 0's if there aren't enough - const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); - float32x4_t av[4]{}; - LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(scale[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // c[m,n] += a[m,k] * b[k,n] - UnrolledLoop<4>([&](size_t j) { - UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); - }); - } - - // increment pointers to next block - QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - QuantBScale += 1; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } - - if constexpr (NCols == 4) { - float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - - if (BiasPtr != nullptr) { - sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); - } - - vst1q_f32(SumPtr, sum); - } else { - for (size_t i = 0; i < NCols; ++i) { - SumPtr[i] = vaddvq_f32(acc[i]); - if (BiasPtr != nullptr) { - SumPtr[i] += BiasPtr[i]; - } - } - } -} - -template -void -SQ4BitGemmM1Kernel_CompFp32_Impl( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; - - const float* ARowPtr = A; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols; - - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols : 0; - SumPtr += NCols; - - nblk -= NCols; - } - - // left over columns less than `NCols`? - nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -void -SQ4BitGemmM1Kernel_CompFp32( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } -} - -// Block dequantize a 16 x NCols section of B from column major source to row major destination. -template -MLAS_FORCEINLINE void -Q4BitBlkDequantB_16xNCols( - const std::byte* QuantBDataPtr, - size_t StrideQuantBData, - const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns - [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns - // only used if HasZeroPoint is true - float* DstColPtr -) -{ - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // write, transposed, 16 x NCols values - if constexpr (NCols == 4) { - UnrolledLoop<4>([&](size_t j) { - Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); - - vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); - vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); - vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); - vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); - }); - } else { - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { - DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); - DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); - DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); - DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); - }); - }); - } -} - -template -void -Q4BitBlkDequantBForSgemm_CompFp32_Impl( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -) -{ - constexpr size_t BlkBitWidth = 4; - - float* Dst = FpData; - - const std::byte* QuantBDataCol = QuantBData; - const float* QuantBScaleCol = QuantBScale; - [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true - - const size_t StrideQuantBData = BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true - MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); - - // - // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. - // - - // scales of blocks from 16 adjacent columns - float scale[16]; - // float conversion offsets (including zero point) of blocks from 16 adjacent columns - [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true - - size_t n_cols_remaining = CountN; - while (n_cols_remaining > 15) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < 16; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - constexpr size_t NCols = 4; - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < 16; nn += NCols) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += NCols; - if constexpr (HasZeroPoint) { - OffsetPtr += NCols; - } - DstColPtr += NCols; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - - n_cols_remaining -= 16; - - QuantBDataCol += 16 * StrideQuantBData; - QuantBScaleCol += 16 * BlockStrideQuantB; - if constexpr (HasZeroPoint) { - QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; - } - } - - if (n_cols_remaining > 0) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - // zero out the 16x16 block in Dst first to ensure zero padding - const float32x4_t zero_v = vdupq_n_f32(0.0f); - UnrolledLoop<16 * 4>([&](size_t i) { - vst1q_f32(Dst + 4 * i, zero_v); - }); - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += 1; - if constexpr (HasZeroPoint) { - OffsetPtr += 1; - } - DstColPtr += 1; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - } -} - -void -Q4BitBlkDequantBForSgemm_CompFp32( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB -) -{ - if (QuantBZeroPoint != nullptr) { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockStrideQuantB - ); - } else { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockStrideQuantB - ); - } -} - -// -// CompInt8 kernel implementation. -// - -template -MLAS_FORCEINLINE void -QuantizeBlock( - size_t BlkLen, - const float* A, - size_t ElementCount, - std::byte* QuantA -) -{ - static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); - - assert(BlkLen % SubBlkLen == 0); - - // - // Scan block values first to determine scale. - // - - float amax = 0.0f; // max of absolute values of A block - - size_t k; - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - float32x4_t abs_a[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - abs_a[i] = vabsq_f32(a[i]); - }); - - // find amax of SubBlkLen elements - for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { - for (size_t i = 0; i < interval; ++i) { - abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); - } - } - - // update existing amax - amax = std::max(amax, vmaxvq_f32(abs_a[0])); - } - - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; - - Q8BlkScale(QuantA) = scale; - - // - // Compute quantized block values. - // - - int8_t* QuantAData = Q8BlkData(QuantA); - - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - UnrolledLoop([&](size_t i) { - a[i] = vmulq_n_f32(a[i], scale_reciprocal); - }); - - int32x4_t a_s32[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - a_s32[i] = vcvtaq_s32_f32(a[i]); - }); - - UnrolledLoop([&](size_t i) { - QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); - QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); - QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); - QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); - }); - } - - // - // Zero out any remaining sub-block elements. - // - - for (; k < BlkLen; k += SubBlkLen) { - const int8x16_t Zeros = vdupq_n_s8(0); - UnrolledLoop([&](size_t i) { - vst1q_s8(QuantAData + k + i * 16, Zeros); - }); - } -} - -void -QuantizeARow_CompInt8( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA -) -{ - const float* ADataBlkPtr = A; - std::byte* QuantABlkPtr = QuantA; - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); - - ADataBlkPtr += BlkLen; - QuantABlkPtr += Q8BlkSize(BlkLen); - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 16; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] >> 4) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); - - // load B - const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); - const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); - int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - bv1 = vsubq_s8(bv1, bzp1); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 8 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - - // load B - const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); - const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 32; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); - const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - bv_lo1 = vsubq_s8(bv_lo1, bzp1); - bv_hi1 = vsubq_s8(bv_hi1, bzp1); - - // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - - // quantized dot product - int32x4_t dot0{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -void -SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen > 32); - assert(BlkLen % 32 == 0); - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - // process blocks in 32-element sub-blocks - const size_t SubBlksPerBlk = BlkLen / 32; - - for (size_t n = 0; n < CountN; ++n) { - const std::byte* QuantAPtr = QuantA; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { - // compute combined scale - const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp = [&]() -> int8x16_t { - if constexpr (HasZeroPoint) { - return vdupq_n_s8( - ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) - : std::to_integer((*QuantBZeroPointPtr) >> 4) - ); - } else { - return vdupq_n_s8(8); - } - }(); - - const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); - - for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { - // load A - const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); - const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); - const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); - const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp); - bv1 = vsubq_s8(bv1, bzp); - bv2 = vsubq_s8(bv2, bzp); - bv3 = vsubq_s8(bv3, bzp); - - // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); - dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale); - - // increment block data pointers to next sub-block - QuantADataPtr += 16 * 4; - QuantBDataPtr += 16 * 2; - } - - // increment other block pointers - - QuantAPtr += Q8BlkSize(BlkLen); - QuantBScalePtr += 1; - - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; - } - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } -} - -void -SQ4BitGemmM1Kernel_CompInt8( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t /*CountK*/, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } -} - -} // namespace +} // namespace sqnbitgemm_neon // // Kernel dispatch structure definition. @@ -1485,17 +178,17 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h new file mode 100644 index 000000000000..ef9345d7ac48 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h @@ -0,0 +1,144 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + SQNBitGemm ARM NEON kernels. + +--*/ + +#pragma once + +#include + +#include +#include +#include + +#include "mlasi.h" + +namespace sqnbitgemm_neon +{ + +// +// Function declarations for SQNBitGemm ARM NEON kernel entry points. +// Refer to the prototypes in sqnbitgemm.h for documentation. +// These are declared here so they can be used to initialize the +// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// files. +// + +// CompFp32 declarations + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +); + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +); + +// CompInt8 declarations + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); + +size_t +SQ4BitGemmKernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + size_t ldc, + const float* Bias +); + +// +// General helpers. +// + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +template +MLAS_FORCEINLINE void +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp new file mode 100644 index 000000000000..ca64ebe3b113 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -0,0 +1,646 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_fp32.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + +--*/ + +#include + +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +namespace +{ + +// +// CompFp32 kernel implementation. +// + +MLAS_FORCEINLINE void +Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + Transpose4x4(a0, a1, a2, a3); + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +namespace fp32_conversion +{ + +// Manual conversion to float takes place in two steps: +// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. +// This target float range is convenient because the 4-bit source values can be placed directly into the +// target float bits. +// 2. Subtract the conversion offset of 16 from the float result. + +// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. +constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; +// sign|exponent|partial mantissa +// +|131: 2^4|~~~~ <- 4 bits go here + +const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + +constexpr float offset = 16.0f; + +} // namespace fp32_conversion + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompFp32( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; + + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + float32x4_t acc[NCols]{}; + + const std::byte* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint is true + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + float scale[NCols]; + UnrolledLoop( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. + // only used if HasZeroPoint is true + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = fp32_conversion::offset + std::to_integer(zp); + }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +template +void +SQ4BitGemmM1Kernel_CompFp32_Impl( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +} // namespace + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + } else { + constexpr bool HasZeroPoint = false; + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockCountK, + Bias + ); + } +} + +namespace +{ + +// Block dequantize a 16 x NCols section of B from column major source to row major destination. +template +MLAS_FORCEINLINE void +Q4BitBlkDequantB_16xNCols( + const std::byte* QuantBDataPtr, + size_t StrideQuantBData, + const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns + [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns + // only used if HasZeroPoint is true + float* DstColPtr +) +{ + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // write, transposed, 16 x NCols values + if constexpr (NCols == 4) { + UnrolledLoop<4>([&](size_t j) { + Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); + + vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); + vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); + vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); + vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); + }); + } else { + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { + DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); + DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); + DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); + DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); + }); + }); + } +} + +template +void +Q4BitBlkDequantBForSgemm_CompFp32_Impl( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth = 4; + + float* Dst = FpData; + + const std::byte* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true + MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + // + // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. + // + + // scales of blocks from 16 adjacent columns + float scale[16]; + // float conversion offsets (including zero point) of blocks from 16 adjacent columns + [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true + + size_t n_cols_remaining = CountN; + while (n_cols_remaining > 15) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < 16; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + constexpr size_t NCols = 4; + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < 16; nn += NCols) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += NCols; + if constexpr (HasZeroPoint) { + OffsetPtr += NCols; + } + DstColPtr += NCols; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + + n_cols_remaining -= 16; + + QuantBDataCol += 16 * StrideQuantBData; + QuantBScaleCol += 16 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; + } + } + + if (n_cols_remaining > 0) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + // zero out the 16x16 block in Dst first to ensure zero padding + const float32x4_t zero_v = vdupq_n_f32(0.0f); + UnrolledLoop<16 * 4>([&](size_t i) { + vst1q_f32(Dst + 4 * i, zero_v); + }); + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += 1; + if constexpr (HasZeroPoint) { + OffsetPtr += 1; + } + DstColPtr += 1; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + } +} + +} // namespace + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockCountK +) +{ + if (QuantBZeroPoint != nullptr) { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockCountK + ); + } else { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockCountK + ); + } +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp new file mode 100644 index 000000000000..db3b9ee65659 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -0,0 +1,1315 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_int8.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + +--*/ + +#include + +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_neon.h" +#include "sqnbitgemm_q8_block.h" + +namespace sqnbitgemm_neon +{ + +// +// CompInt8 kernel implementation. +// + +namespace +{ + +template +MLAS_FORCEINLINE void +QuantizeBlock( + size_t BlkLen, + const float* A, + size_t ElementCount, + std::byte* QuantA +) +{ + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); + + assert(BlkLen % SubBlkLen == 0); + + // + // Scan block values first to determine scale. + // + + float amax = 0.0f; // max of absolute values of A block + + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } + + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. + // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +} // namespace + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } +} + +namespace +{ + +// +// The ComputeRxC functions compute an R row by C column tile of the output matrix. +// + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK, + size_t StrideQuantA, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + size_t ldc +) +{ + constexpr size_t BlkLen = 16; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlkRow0 = QuantAPtr; + const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + + const float QuantBScaleCol0 = *QuantBScalePtr; + const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); + + // compute combined scales + const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; + const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; + const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; + const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + + // load B zero point + int8_t bzp_col0; + int8_t bzp_col1; + if constexpr (HasZeroPoint) { + const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; + const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); + if ((k_blk_idx & 1) == 0) { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); + } else { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); + } + } else { + bzp_col0 = 8; + bzp_col1 = 8; + } + + const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); + const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + + // TODO handling only 16 elements per accumulator at a time here, probably can do better + { + // load A + const int8x16_t av_row0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row1 = vld1q_s8(QuantADataPtrRow1 + 0); + + // load B + const uint8x8_t bv_packed_col0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x8_t bv_packed_col1 = vld1_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); + + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + int8x16_t bv_col0 = vreinterpretq_s8_u8( + vcombine_u8( + vand_u8(bv_packed_col0, LowMaskU8x8), + vshr_n_u8(bv_packed_col0, 4) + ) + ); + int8x16_t bv_col1 = vreinterpretq_s8_u8( + vcombine_u8( + vand_u8(bv_packed_col1, LowMaskU8x8), + vshr_n_u8(bv_packed_col1, 4) + ) + ); + + // subtract B zero point + bv_col0 = vsubq_s8(bv_col0, vdupq_n_s8(bzp_col0)); + bv_col1 = vsubq_s8(bv_col1, vdupq_n_s8(bzp_col1)); + + // quantized dot product + int32x4_t dot00{}, dot01{}, dot10{}, dot11{}; + dot00 = vdotq_s32(dot00, av_row0, bv_col0); + dot01 = vdotq_s32(dot01, av_row0, bv_col1); + dot10 = vdotq_s32(dot10, av_row1, bv_col0); + dot11 = vdotq_s32(dot11, av_row1, bv_col1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + } + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBDataPtr += 8; + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + SumPtr[0] = vaddvq_f32(acc00); + SumPtr[1] = vaddvq_f32(acc01); + SumPtr[ldc + 0] = vaddvq_f32(acc10); + SumPtr[ldc + 1] = vaddvq_f32(acc11); + + if (BiasPtr != nullptr) { + SumPtr[0] += BiasPtr[0]; + SumPtr[1] += BiasPtr[1]; + SumPtr[ldc + 0] += BiasPtr[0]; + SumPtr[ldc + 1] += BiasPtr[1]; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK, + size_t StrideQuantA, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + size_t ldc +) +{ + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlkRow0 = QuantAPtr; + const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + + const float QuantBScaleCol0 = *QuantBScalePtr; + const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); + + // compute combined scales + const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; + const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; + const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; + const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + + // load B zero point + int8_t bzp_col0; + int8_t bzp_col1; + if constexpr (HasZeroPoint) { + const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; + const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); + if ((k_blk_idx & 1) == 0) { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); + } else { + bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); + bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); + } + } else { + bzp_col0 = 8; + bzp_col1 = 8; + } + + const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); + const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) { + // load A + const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16); + const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0); + const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16); + + // load B + const uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed_col1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + int8x16_t bv_col0_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col0, LowMaskU8x16)); + int8x16_t bv_col0_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col0, 4)); + int8x16_t bv_col1_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col1, LowMaskU8x16)); + int8x16_t bv_col1_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col1, 4)); + + // subtract B zero point + bv_col0_0 = vsubq_s8(bv_col0_0, vdupq_n_s8(bzp_col0)); + bv_col0_1 = vsubq_s8(bv_col0_1, vdupq_n_s8(bzp_col0)); + bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1)); + bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1)); + + // quantized dot product + int32x4_t dot00{}, dot01{}, dot10{}, dot11{}; + dot00 = vdotq_s32(vdotq_s32(dot00, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1); + dot01 = vdotq_s32(vdotq_s32(dot01, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1); + dot10 = vdotq_s32(vdotq_s32(dot10, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1); + dot11 = vdotq_s32(vdotq_s32(dot11, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + + // increment block data pointers to next sub-block + QuantADataPtrRow0 += 32; + QuantADataPtrRow1 += 32; + QuantBDataPtr += 16; + } + + // increment other block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + SumPtr[0] = vaddvq_f32(acc00); + SumPtr[1] = vaddvq_f32(acc01); + SumPtr[ldc + 0] = vaddvq_f32(acc10); + SumPtr[ldc + 1] = vaddvq_f32(acc11); + + if (BiasPtr != nullptr) { + SumPtr[0] += BiasPtr[0]; + SumPtr[1] += BiasPtr[1]; + SumPtr[ldc + 0] += BiasPtr[0]; + SumPtr[ldc + 1] += BiasPtr[1]; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + constexpr size_t BlkLen = 16; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); + + // load B + const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); + const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); + int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + bv1 = vsubq_s8(bv1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 8 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + + // load B + const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); + const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + constexpr size_t BlkLen = 32; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); + const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + bv_lo1 = vsubq_s8(bv_lo1, bzp1); + bv_hi1 = vsubq_s8(bv_hi1, bzp1); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + + // quantized dot product + int32x4_t dot0{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + const float* BiasPtr, + float* SumPtr, + size_t BlockCountK +) +{ + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + const std::byte* QuantAPtr = QuantARowPtr; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + + // load B zero point + const int8x16_t bzp = [&]() -> int8x16_t { + if constexpr (HasZeroPoint) { + return vdupq_n_s8( + ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) + : std::to_integer((*QuantBZeroPointPtr) >> 4) + ); + } else { + return vdupq_n_s8(8); + } + }(); + + const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { + // load A + const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); + const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); + const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); + const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp); + bv1 = vsubq_s8(bv1, bzp); + bv2 = vsubq_s8(bv2, bzp); + bv3 = vsubq_s8(bv3, bzp); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); + dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale); + + // increment block data pointers to next sub-block + QuantADataPtr += 16 * 4; + QuantBDataPtr += 16 * 2; + } + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; + } + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } +} + +template +MLAS_FORCEINLINE void +AdvanceColPtrs( + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const std::byte*& QuantBDataColPtr, + const float*& QuantBScaleColPtr, + const std::byte*& QuantBZeroPointColPtr, + const float*& BiasPtr, + float*& SumPtr +) +{ + QuantBDataColPtr += NumCols * StrideQuantBData; + QuantBScaleColPtr += NumCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NumCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NumCols : 0; + SumPtr += NumCols; +} + +template +MLAS_FORCEINLINE void +AdvanceRowPtrs( + size_t StrideQuantA, + size_t ldc, + const std::byte*& QuantARowPtr, + float*& SumRowPtr +) +{ + QuantARowPtr += NumRows * StrideQuantA; + SumRowPtr += NumRows * ldc; +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLen16( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 16; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 1) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 2x2 tiles of output + SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 2x1 tile of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr + StrideQuantA, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc, + BlockCountK + ); + } + + // Move to next 2 rows + AdvanceRowPtrs<2>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 2; + } + + if (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + } +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLen32( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 1) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 2x2 tiles of output + SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 2x1 tile of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr + StrideQuantA, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc, + BlockCountK + ); + } + + // Move to next 2 rows + AdvanceRowPtrs<2>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 2; + } + + if (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + } +} + +template +void +SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const std::byte* QuantARowPtr = QuantA; + + float* SumRowPtr = C; + + size_t m_remaining = CountM; + while (m_remaining > 1) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 1) { + // Compute 2x2 tiles of output + SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK, + StrideQuantA, + StrideQuantBData, + StrideQuantBScale, + StrideQuantBZeroPoint, + ldc + ); + + // Move to next 2 columns + AdvanceColPtrs<2, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 2; + } + + if (n_remaining > 0) { + // Compute last 2x1 tile of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr + StrideQuantA, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc, + BlockCountK + ); + } + + // Move to next 2 rows + AdvanceRowPtrs<2>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 2; + } + + if (m_remaining > 0) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + const float* BiasPtr = Bias; + + float* SumPtr = SumRowPtr; + + size_t n_remaining = CountN; + while (n_remaining > 0) { + // Compute 1x1 tiles of output + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr, + BlockCountK + ); + + // Move to next column + AdvanceColPtrs<1, HasZeroPoint>( + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr + ); + + n_remaining -= 1; + } + } +} + +template +void +SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmKernel_CompInt8_BlkLen16( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmKernel_CompInt8_BlkLen32( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else { + SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } +} + +} // namespace + +size_t +SQ4BitGemmKernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + size_t ldc, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } else { + constexpr bool HasZeroPoint = false; + SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + BlockCountK, + ldc, + Bias + ); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/optimizer/attention_fusion.h b/onnxruntime/core/optimizer/attention_fusion.h index acb478da5f31..befb66b5aa96 100644 --- a/onnxruntime/core/optimizer/attention_fusion.h +++ b/onnxruntime/core/optimizer/attention_fusion.h @@ -19,7 +19,8 @@ class AttentionFusion : public GraphTransformer { Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: - static bool FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, std::map& mask_index_map, const logging::Logger& logger); + static bool FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, + std::map& mask_index_map, const logging::Logger& logger); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 6233e5b839bb..ca744adddbee 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -23,7 +23,8 @@ struct MatchGemmResult { }; // Compare the expected parameters (starts, ends, axes and step) -bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vector& input_indices, const std::vector& expected_values, const logging::Logger& logger) { +bool CheckSliceParameters(const Graph& graph, const Node& slice, const std::vector& input_indices, + const std::vector& expected_values, const logging::Logger& logger) { ORT_ENFORCE(input_indices.size() == expected_values.size() && input_indices.size() > 0); // Here assumes that the last element of input_indices is the maximum one. diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc index 48df511d0c67..471e4ee7c03a 100644 --- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc +++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc @@ -491,7 +491,8 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i if (graph_outputs.count(output_def) > 0) { // Currently, we don't support eliminating the graph's outputs. - LOGS(logger, VERBOSE) << "Not eliminating output " << output_def->Name() << " of node " << node->Name() << "[" << node->OpType() << "] because it's the graph's output."; + LOGS(logger, VERBOSE) << "Not eliminating output " << output_def->Name() << " of node " << node->Name() + << "[" << node->OpType() << "] because it's the graph's output."; continue; } diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.h b/onnxruntime/core/optimizer/free_dim_override_transformer.h index f9553339a7ce..18e0b128b864 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.h +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/optimizer/graph_transformer.h" diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h index 098278a77daf..f431d98b3b82 100644 --- a/onnxruntime/core/optimizer/gather_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -10,7 +10,7 @@ namespace onnxruntime { /** @Class GatherSliceToSplitFusion -Fuse multiple Gather/Slice nodes that comsuming one output to one Split node. +Fuse multiple Gather/Slice nodes that consuming one output to one Split node. */ class GatherSliceToSplitFusion : public GraphTransformer { public: diff --git a/onnxruntime/core/optimizer/gemm_sum_fusion.h b/onnxruntime/core/optimizer/gemm_sum_fusion.h index 0e2ec104703f..9b2fa22ecc31 100644 --- a/onnxruntime/core/optimizer/gemm_sum_fusion.h +++ b/onnxruntime/core/optimizer/gemm_sum_fusion.h @@ -12,14 +12,14 @@ namespace onnxruntime { Rewrite rule that fuses Gemm and Sum nodes to a single Gemm node. This fusion can be applied in the following scenario: -1) Sum at output of Gemm: when the output of a Gemm is immedietly summed with +1) Sum at output of Gemm: when the output of a Gemm is immediately summed with exactly one other element, we can fuse this Sum with Gemm by using the other Sum input as C, provided that the C input to the Gemm is missing. This is supported for opset >= 11, as this is when Gemm input C became optional. TODO: Support the Add use case: Sum(x, y) ~= Add. -This patterm is attempted to be triggered only on nodes with op type "Gemm". +This pattern is attempted to be triggered only on nodes with op type "Gemm". A --> Gemm --> D --> Sum --> E ^ ^ diff --git a/onnxruntime/core/optimizer/identity_elimination.h b/onnxruntime/core/optimizer/identity_elimination.h index 5e76275207c3..4b20edec12df 100644 --- a/onnxruntime/core/optimizer/identity_elimination.h +++ b/onnxruntime/core/optimizer/identity_elimination.h @@ -26,6 +26,6 @@ class EliminateIdentity : public RewriteRule { bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; -}; // namespace onnxruntime +}; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 5953935203b8..33fd613bb1a5 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_external_data_info.h" #include "core/platform/env.h" diff --git a/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc b/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc index bc0523d517ee..7d249ea715e8 100644 --- a/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc +++ b/onnxruntime/core/optimizer/isinf_reducesum_fusion.cc @@ -41,7 +41,7 @@ Status IsInfReduceSumFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l auto input_defs = isinf_node.MutableInputDefs(); // see if there is a Cast before IsInf - // This will happen if input type is FP16 but IsInf doesnt support fp16, so it will be cast to float/double + // This will happen if input type is FP16 but IsInf doesn't support fp16, so it will be cast to float/double // This Cast can be skipped as we are replacing the subgraph with IsAllFinite, which supports FP16 auto cast1_node_iter = isinf_node.InputNodesBegin(); if (cast1_node_iter != isinf_node.InputNodesEnd() && diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index 7a43483cf37d..39cd0dd186d5 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -8,7 +8,7 @@ namespace onnxruntime { /* * This fusion submerges a BatchNormalization operator to it's super - * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * preceding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() * is true. */ class MatmulBNFusion : public RewriteRule { @@ -24,4 +24,4 @@ class MatmulBNFusion : public RewriteRule { Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index de36202afff2..f38820655117 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -5,7 +5,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/framework/node_unit.h" #include "core/graph/basic_types.h" diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index e52ab16efe95..9384bfa7027c 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -6,7 +6,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/graph/graph_utils.h" // TODO: Minimize usage of this given we want to use Actions in a minimal build #include "core/graph/runtime_optimization_record.h" #include "core/optimizer/selectors_actions/helpers.h" diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.h b/onnxruntime/core/optimizer/selectors_actions/helpers.h index cf5489dc1960..c3d50d6de05a 100644 --- a/onnxruntime/core/optimizer/selectors_actions/helpers.h +++ b/onnxruntime/core/optimizer/selectors_actions/helpers.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/basic_types.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/graph/graph.h" #include "core/graph/runtime_optimization_record.h" diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 0d7ab70eba61..f1e94dd4fe9e 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -11,7 +11,7 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { // implements MemCpy node insertion in graph transform -// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTranformer +// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer class TransformerMemcpyImpl { public: TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index e6ffd0d91372..d4ed9c4e26cc 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -12,7 +12,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/make_string.h" #include "core/graph/constants.h" diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 6917f42091bf..f4dff2c49113 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -22,7 +22,7 @@ limitations under the License. #include #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/path_string.h" diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 9999550c241c..16d135c3acb2 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -44,7 +44,7 @@ limitations under the License. #endif #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/platform/scoped_resource.h" diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index dc090e446e60..712b69593a68 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -26,7 +26,7 @@ limitations under the License. #include #include -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/common/span_utils.h" diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index b0f9eaf4f62d..2a74e2285065 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -98,10 +98,6 @@ ULONGLONG EtwRegistrationManager::Keyword() const { return keyword_; } -HRESULT EtwRegistrationManager::Status() const { - return etw_status_; -} - void EtwRegistrationManager::RegisterInternalCallback(const EtwInternalCallback& callback) { std::lock_guard lock(callbacks_mutex_); callbacks_.push_back(&callback); @@ -144,15 +140,9 @@ EtwRegistrationManager::EtwRegistrationManager() { } void EtwRegistrationManager::LazyInitialize() { - if (!initialized_) { - std::lock_guard lock(init_mutex_); - if (!initialized_) { // Double-check locking pattern - initialized_ = true; - etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); - if (FAILED(etw_status_)) { - ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status_)); - } - } + static HRESULT etw_status = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); + if (FAILED(etw_status)) { + ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status)); } } @@ -171,12 +161,6 @@ void EtwSink::SendImpl(const Timestamp& timestamp, const std::string& logger_id, // register on first usage static EtwRegistrationManager& etw_manager = EtwRegistrationManager::Instance(); - // do something (not that meaningful) with etw_manager so it doesn't get optimized out - // as we want an instance around to do the unregister - if (FAILED(etw_manager.Status())) { - return; - } - // TODO: Validate if this filtering makes sense. if (message.DataType() == DataType::USER) { return; diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 3af45b813a62..ff68aec0b7d6 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -66,9 +66,6 @@ class EtwRegistrationManager { // Get the current keyword uint64_t Keyword() const; - // Get the ETW registration status - HRESULT Status() const; - void RegisterInternalCallback(const EtwInternalCallback& callback); void UnregisterInternalCallback(const EtwInternalCallback& callback); @@ -100,7 +97,6 @@ class EtwRegistrationManager { bool is_enabled_; UCHAR level_; ULONGLONG keyword_; - HRESULT etw_status_; }; } // namespace logging diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index d7d423e4a483..3401507ae911 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -12,7 +12,7 @@ #endif #include "core/common/logging/logging.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 3e6c43ab0786..97fb83b6dc48 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -7,7 +7,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 6c2fcc2ace85..3400f09b4056 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -259,13 +259,13 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa // Onnx spec requires output sizes to be a positive integer, so we are not checking that here if (output_size_h % input_size_h != 0) { LOGS(logger, VERBOSE) << "Resize: output_size_h: " << output_size_h - << " is not a mutliple of input_size_h: " << input_size_h; + << " is not a multiple of input_size_h: " << input_size_h; return false; } if (output_size_w % input_size_w != 0) { LOGS(logger, VERBOSE) << "Resize: output_size_w: " << output_size_w - << " is not a mutliple of input_size_w: " << input_size_w; + << " is not a multiple of input_size_w: " << input_size_w; return false; } } diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index e3cd43d786fc..c4c3b38bba51 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -8,7 +8,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/platform/ort_mutex.h" diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 3edcdb3f95e4..1d506099b436 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -13,7 +13,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" diff --git a/onnxruntime/core/providers/coreml/shape_utils.h b/onnxruntime/core/providers/coreml/shape_utils.h index 0a1fd47cfdfe..23ee51af63d4 100644 --- a/onnxruntime/core/providers/coreml/shape_utils.h +++ b/onnxruntime/core/providers/coreml/shape_utils.h @@ -7,7 +7,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/logging/logging.h" #include "core/graph/node_arg.h" diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index 9837aabe786c..c65dd2a04bf5 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -23,7 +23,7 @@ #include "core/framework/TensorSeq.h" #include "core/providers/utils.h" -#include "core/common/gsl.h" +#include #ifdef _MSC_VER #pragma warning(pop) diff --git a/onnxruntime/core/providers/cpu/controlflow/scan.h b/onnxruntime/core/providers/cpu/controlflow/scan.h index 76c27827b99a..8516fa786da3 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan.h +++ b/onnxruntime/core/providers/cpu/controlflow/scan.h @@ -3,7 +3,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #ifndef SHARED_PROVIDER #include "core/common/common.h" diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_8.cc b/onnxruntime/core/providers/cpu/controlflow/scan_8.cc index 67853790415e..cea321757812 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_8.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_8.cc @@ -246,9 +246,9 @@ Status Scan8Impl::ValidateSubgraphInput(int start_input, int end_input, bool is_ auto this_batch_size = input_shape[0]; - if (batch_size_ < 0) + if (batch_size_ < 0) { batch_size_ = this_batch_size; - else { + } else { if (batch_size_ != this_batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Scan inputs have inconsistent batch size. Previous value was ", batch_size_, " but ", graph_inputs[i]->Name(), " has batch size of ", @@ -263,7 +263,8 @@ Status Scan8Impl::ValidateSubgraphInput(int start_input, int end_input, bool is_ max_sequence_len_ = this_seq_len; } else { if (max_sequence_len_ != this_seq_len) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Scan inputs have inconsistent sequence lengths. Previous value was ", + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Scan inputs have inconsistent sequence lengths. Previous value was ", max_sequence_len_, " but ", graph_inputs[i]->Name(), " has length of ", this_seq_len); } @@ -396,13 +397,15 @@ Status Scan8Impl::Execute(const FeedsFetchesManager& ffm) { // Setup input OrtValue streams std::vector::Iterator> scan_input_stream_iterators; - scan_input_stream_iterators.reserve(static_cast(info_.num_variadic_inputs) - info_.num_loop_state_variables); + scan_input_stream_iterators.reserve(static_cast(info_.num_variadic_inputs) - + info_.num_loop_state_variables); for (int i = info_.num_loop_state_variables, end = info_.num_variadic_inputs; i < end; ++i) { const auto& ort_value = GetSubgraphInputMLValue(context_, i); // forward - if (directions_[static_cast(i) - info_.num_loop_state_variables] == static_cast(ScanDirection::kForward)) { + if (directions_[static_cast(i) - info_.num_loop_state_variables] == + static_cast(ScanDirection::kForward)) { // the iterator is self contained, so we don't need to keep the OrtValueTensorSlicer instance around scan_input_stream_iterators.push_back(device_helpers_.create_const_slicer_func(ort_value, 1, b).begin()); } else { // reverse @@ -417,8 +420,10 @@ Status Scan8Impl::Execute(const FeedsFetchesManager& ffm) { } // Call the subgraph for each item in the sequence - status = IterateSequence(context_, session_state_, batch_loop_state_variables[onnxruntime::narrow(b)], scan_input_stream_iterators, - sequence_len, info_.num_loop_state_variables, info_.num_variadic_inputs, info_.num_outputs, + status = IterateSequence(context_, session_state_, batch_loop_state_variables[onnxruntime::narrow(b)], + scan_input_stream_iterators, + sequence_len, info_.num_loop_state_variables, info_.num_variadic_inputs, + info_.num_outputs, implicit_inputs_, output_iterators_, ffm); // zero out any remaining values in the sequence diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_9.cc b/onnxruntime/core/providers/cpu/controlflow/scan_9.cc index f7548fbf6050..24d233c0594f 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_9.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_9.cc @@ -21,7 +21,7 @@ #include "core/providers/cpu/tensor/utils.h" #include "core/providers/cpu/tensor/transpose.h" -#include "core/common/gsl.h" +#include #ifdef _MSC_VER #pragma warning(pop) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index c4a83efa01a9..fd7b19dea724 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -192,6 +192,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor* p) override { return p->Run(); } + void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const override { + p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + } #ifndef DISABLE_CONTRIB_OPS Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) override { @@ -294,12 +299,6 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } - void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, - gsl::span input_dims, - InlinedVector& scales) const override { - p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); - } - #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index c0e674827e4d..840d6f8e3e7a 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -141,7 +141,9 @@ struct ProviderHostCPU { virtual Status Scan__Compute(const Scan<9>* p, OpKernelContext* ctx) = 0; virtual Status Scan__SetupSubgraphExecutionInfo(Scan<8>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; virtual Status Scan__SetupSubgraphExecutionInfo(Scan<9>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; - + virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const = 0; #ifndef DISABLE_CONTRIB_OPS virtual Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) = 0; virtual Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) = 0; @@ -203,10 +205,6 @@ struct ProviderHostCPU { virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; - virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, - gsl::span input_dims, - InlinedVector& scales) const = 0; - #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 2ff6549794ff..8a0390fe7af8 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -4,7 +4,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/core/providers/cpu/math/hardmax.h b/onnxruntime/core/providers/cpu/math/hardmax.h index 02b9b96fd3bf..1b77a30a164e 100644 --- a/onnxruntime/core/providers/cpu/math/hardmax.h +++ b/onnxruntime/core/providers/cpu/math/hardmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/core/providers/cpu/math/sign.cc b/onnxruntime/core/providers/cpu/math/sign.cc index 60080135bbd2..1d3b444c83b6 100644 --- a/onnxruntime/core/providers/cpu/math/sign.cc +++ b/onnxruntime/core/providers/cpu/math/sign.cc @@ -3,7 +3,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/data_types.h" diff --git a/onnxruntime/core/providers/cpu/math/softmax.h b/onnxruntime/core/providers/cpu/math/softmax.h index 448a97bfbe0a..cac674b42945 100644 --- a/onnxruntime/core/providers/cpu/math/softmax.h +++ b/onnxruntime/core/providers/cpu/math/softmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 0b6c35ffabb1..cae20b42725b 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -22,7 +22,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/util/math.h" #include "core/util/math_cpuonly.h" diff --git a/onnxruntime/core/providers/cpu/ml/cast_map.cc b/onnxruntime/core/providers/cpu/ml/cast_map.cc index 8dcc3393f581..f21eb99dd64e 100644 --- a/onnxruntime/core/providers/cpu/ml/cast_map.cc +++ b/onnxruntime/core/providers/cpu/ml/cast_map.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/cast_map.h" #include -#include "core/common/gsl.h" +#include using namespace ::onnxruntime::common; namespace { diff --git a/onnxruntime/core/providers/cpu/ml/category_mapper.cc b/onnxruntime/core/providers/cpu/ml/category_mapper.cc index f56e98a286cf..88d83b2a2d33 100644 --- a/onnxruntime/core/providers/cpu/ml/category_mapper.cc +++ b/onnxruntime/core/providers/cpu/ml/category_mapper.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/category_mapper.h" #include -#include "core/common/gsl.h" +#include using namespace ::onnxruntime::common; namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc b/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc index 24fbfac59473..6e46b2279f7d 100644 --- a/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc +++ b/onnxruntime/core/providers/cpu/ml/feature_vectorizer.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/feature_vectorizer.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace ml { diff --git a/onnxruntime/core/providers/cpu/ml/label_encoder.cc b/onnxruntime/core/providers/cpu/ml/label_encoder.cc index 65102b62a963..67f38638d8da 100644 --- a/onnxruntime/core/providers/cpu/ml/label_encoder.cc +++ b/onnxruntime/core/providers/cpu/ml/label_encoder.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/ml/label_encoder.h" #include -#include "core/common/gsl.h" +#include using namespace ::onnxruntime::common; namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/nn/flatten.h b/onnxruntime/core/providers/cpu/nn/flatten.h index 45fb644d2a0e..b776d5e28d57 100644 --- a/onnxruntime/core/providers/cpu/nn/flatten.h +++ b/onnxruntime/core/providers/cpu/nn/flatten.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/common/gsl.h" +#include #include "core/providers/cpu/tensor/utils.h" #include "core/providers/common.h" diff --git a/onnxruntime/core/providers/cpu/nn/lrn.h b/onnxruntime/core/providers/cpu/nn/lrn.h index e797ffda87f7..dc27672aa056 100644 --- a/onnxruntime/core/providers/cpu/nn/lrn.h +++ b/onnxruntime/core/providers/cpu/nn/lrn.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/exceptions.h" diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index dfc7a2b68699..6d54c24b3808 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -18,7 +18,7 @@ #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace rnn { diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index cbc4d8360d4b..6742bab4fa4a 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/narrow.h" diff --git a/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc b/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc index 62c3dbfc87a7..6eef6b35a8fd 100644 --- a/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc +++ b/onnxruntime/core/providers/cpu/tensor/mean_variance_normalization.cc @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/common.h" #include "core/providers/cpu/tensor/transpose.h" #include "core/util/math_cpuonly.h" diff --git a/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc b/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc index a03d1143d058..31ce89b3ed55 100644 --- a/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc +++ b/onnxruntime/core/providers/cpu/tensor/reverse_sequence.cc @@ -11,7 +11,7 @@ #pragma warning(disable : 4996) #endif -#include "core/common/gsl.h" +#include #ifdef _MSC_VER #pragma warning(pop) diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.h b/onnxruntime/core/providers/cpu/tensor/shape_op.h index b9e938995019..05d22595dd83 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.h +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.h @@ -9,7 +9,7 @@ #include "core/framework/op_kernel.h" #endif -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h b/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h index 05f1479315fe..c21cc13ad316 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h +++ b/onnxruntime/core/providers/cpu/tensor/slice_compute_metadata.h @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/tensor_shape.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc index 724814e2f1d5..4e43085fe288 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.cc +++ b/onnxruntime/core/providers/cpu/tensor/split.cc @@ -4,7 +4,7 @@ #include "core/providers/cpu/tensor/split.h" #include "core/common/narrow.h" -#include "core/common/gsl.h" +#include #include "core/common/safeint.h" #include "core/framework/copy.h" #include "core/framework/element_type_lists.h" diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h index fda41c28a256..54d3584ba0da 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.h +++ b/onnxruntime/core/providers/cpu/tensor/transpose.h @@ -10,7 +10,7 @@ #include "core/framework/op_kernel.h" #endif -#include "core/common/gsl.h" +#include #include namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/unique.cc b/onnxruntime/core/providers/cpu/tensor/unique.cc index 135bef0860ca..ab99d87da83f 100644 --- a/onnxruntime/core/providers/cpu/tensor/unique.cc +++ b/onnxruntime/core/providers/cpu/tensor/unique.cc @@ -4,7 +4,7 @@ #include "core/providers/cpu/tensor/unique.h" #include #include -#include "core/common/gsl.h" +#include #include "core/framework/op_kernel_type_control_utils.h" #include "core/providers/common.h" #include "core/providers/op_kernel_type_control.h" diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h index 17eac5417f0a..6adcfec85269 100644 --- a/onnxruntime/core/providers/cpu/tensor/utils.h +++ b/onnxruntime/core/providers/cpu/tensor/utils.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/common/gsl.h" +#include #include "core/common/narrow.h" #ifndef SHARED_PROVIDER diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 61da125b4095..0b56cac1038e 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -16,7 +16,7 @@ #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fast_divmod.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 103c79c93b2c..7851da7fa91a 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -10,7 +10,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_execution_provider_info.h" diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 3c0bf183362d..58e57572131b 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -81,6 +81,9 @@ CudaStream::CudaStream(cudaStream_t stream, cudnn_handle_ = external_cudnn_handle; CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream)); } +#else + (void)(external_cudnn_handle); + (void)(external_cublas_handle); #endif } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 914dc02a9eda..31fc63a86d64 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -6,7 +6,7 @@ #include "core/providers/cuda/cudnn_common.h" #include "core/common/inlined_containers.h" -#include "core/common/gsl.h" +#include #include "shared_inc/cuda_call.h" #include "core/providers/cpu/tensor/utils.h" #ifndef USE_CUDA_MINIMAL diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h index bbe63e66e67d..6f4016b655c9 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.h +++ b/onnxruntime/core/providers/cuda/math/softmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cuda/multi_tensor/common.cuh b/onnxruntime/core/providers/cuda/multi_tensor/common.cuh index 2d7928776bb8..49a0a9c51472 100644 --- a/onnxruntime/core/providers/cuda/multi_tensor/common.cuh +++ b/onnxruntime/core/providers/cuda/multi_tensor/common.cuh @@ -10,7 +10,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 2bc937526f94..7b827f8d0459 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/gru.h b/onnxruntime/core/providers/cuda/rnn/gru.h index 6f5c5ab6e956..e5ea1ed3e670 100644 --- a/onnxruntime/core/providers/cuda/rnn/gru.h +++ b/onnxruntime/core/providers/cuda/rnn/gru.h @@ -4,7 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index 2df0a38d22f5..1f7df9b6fc2e 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -11,7 +11,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/float16.h" #include "core/providers/cuda/shared_inc/fast_divmod.h" diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.h b/onnxruntime/core/providers/cuda/tensor/transpose.h index da73d1308dbf..31bb21078c1b 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose.h @@ -4,7 +4,7 @@ #pragma once #include "core/providers/shared_library/provider_api.h" -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cpu/tensor/transpose.h" diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp index c6a87da705a9..32d6af73aae8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp @@ -40,16 +40,32 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0) ); } + MLOperatorTensorDataType ADatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::A).tensorDataType; MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType; + gsl::span outputSizes = m_outputTensorDescs[0].GetSizes(); std::vector ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A); - std::vector ExpectedAScaleTensorShape = {1, 1, 1, 1}; - std::vector ExpectedAZeroPointTensorShape = {1, 1, 1, 1}; + std::vector BTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::B); + std::vector ExpectedAScaleTensorShape(outputSizes.size(), 1); + std::vector ExpectedAZeroPointTensorShape(outputSizes.size(), 1); + ML_CHECK_VALID_ARGUMENT(outputSizes.size() >= 4); + ML_CHECK_VALID_ARGUMENT(ATensorShape.size() >= 2); + ML_CHECK_VALID_ARGUMENT(BTensorShape.size() >= 2); + ML_CHECK_VALID_ARGUMENT(ATensorShape.size() + 2 >= outputSizes.size()); + ML_CHECK_VALID_ARGUMENT(BTensorShape.size() + 2 >= outputSizes.size()); + std::vector AShapeBroadcasted(outputSizes.begin(), outputSizes.end()); + std::copy(ATensorShape.end() - (outputSizes.size() - 2), + ATensorShape.end(), + AShapeBroadcasted.begin() + 2); + std::vector BShapeBroadcasted(outputSizes.begin(), outputSizes.end()); + std::copy(BTensorShape.end() - (outputSizes.size() - 2), + BTensorShape.end(), + BShapeBroadcasted.begin() + 2); // output edges between DynQL and MMItoFloat node TensorDesc intermediateQuantizedATensorDesc = TensorDesc( BDatatype, - gsl::make_span(ATensorShape), + gsl::make_span(AShapeBroadcasted), gsl::make_span(ATensorShape), TensorAxis::DoNotCoerce, TensorAxis::W, @@ -80,6 +96,30 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator 0 // guaranteedBaseOffsetAlignment ); + TensorDesc broadcastedATensorDesc = TensorDesc( + ADatatype, + AShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting). + ATensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'. + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + TensorDesc broadcastedBTensorDesc = TensorDesc( + BDatatype, + BShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting). + BTensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'. + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + DML_TENSOR_DESC namedBroadcastedATensorDesc = broadcastedATensorDesc.GetDmlDesc(); + DML_TENSOR_DESC namedBroadcastedBTensorDesc = broadcastedBTensorDesc.GetDmlDesc(); DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc(); DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc(); DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc(); @@ -88,7 +128,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator std::vector outputDescs = GetDmlOutputDescs(); DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {}; - dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A]; + dynamicQuantizeLinearOperatorDesc.InputTensor = &namedBroadcastedATensorDesc; dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc; dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc; dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc; @@ -99,7 +139,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor; matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor; matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor; - matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B]; + matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &namedBroadcastedBTensorDesc; matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale]; matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr; matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h index 5ecc49f23748..e9df3fd20aff 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h @@ -26,7 +26,7 @@ #include #include -#include "core/common/gsl.h" +#include #ifdef _GAMING_XBOX_SCARLETT #include diff --git a/onnxruntime/core/providers/dml/GraphTransformers/precomp.h b/onnxruntime/core/providers/dml/GraphTransformers/precomp.h index a2cce6baed7b..7b146e3c4d9b 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/precomp.h +++ b/onnxruntime/core/providers/dml/GraphTransformers/precomp.h @@ -14,4 +14,4 @@ #include #include -#include "core/common/gsl.h" +#include diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 32079a90ea73..686cdbe774a4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -5,7 +5,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" -#include "core/common/gsl.h" +#include #include #ifdef ORT_NO_EXCEPTIONS diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h index a64d1e01c6cc..6c47e60e63b8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/precomp.h @@ -17,4 +17,4 @@ #include #include -#include "core/common/gsl.h" +#include diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h index 718469f740d4..831b10c3e147 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h @@ -66,9 +66,9 @@ class DnnlConv { * - For Onnx a non-dilated kernel would be all 1s * - For OneDNN a non-dilated kernel would be all 0s * - * The memory dimentions returned is in the form expected for OneDNN each dilation dimention - * will be 1 less than the dilated dimention expected by Onnx specification. Be aware of this - * fact as 'dilations' are used in any calcuations since this could result in an off-by-one + * The memory dimensions returned is in the form expected for OneDNN each dilation dimension + * will be 1 less than the dilated dimension expected by Onnx specification. Be aware of this + * fact as 'dilations' are used in any calculations since this could result in an off-by-one * error. */ dnnl::memory::dims GetDilations(DnnlNode& node, ConvShape shape); @@ -115,4 +115,4 @@ class DnnlConv { }; } // namespace ort_dnnl -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h index f0928974b131..3a27788745ef 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_convgrad.h @@ -49,9 +49,9 @@ class DnnlConvGrad { * - For Onnx a non-dilated kernel would be all 1s * - For OneDNN a non-dilated kernel would be all 0s * - * The memory dimentions returned is in the form expected for OneDNN each dilation dimention - * will be 1 less than the dilated dimention expected by Onnx specification. Be aware of this - * fact as 'dilations' are used in any calcuations since this could result in an off-by-one + * The memory dimensions returned is in the form expected for OneDNN each dilation dimension + * will be 1 less than the dilated dimension expected by Onnx specification. Be aware of this + * fact as 'dilations' are used in any calculations since this could result in an off-by-one * error. */ dnnl::memory::dims GetDilations(DnnlNode& node, ConvShape shape); @@ -62,4 +62,4 @@ class DnnlConvGrad { }; } // namespace ort_dnnl -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h index d1cea23fca24..dac4e743ea19 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_qattention.h @@ -20,7 +20,7 @@ class DnnlQAttention { MASK_INDEX = 5, INPUT_ZP = 6, WEIGHTS_ZP = 7, - PAST = 8 // not suppoted + PAST = 8 // not supported }; enum OutputTensors : int { diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 59dfbb0e9492..c51bf5ce9d4a 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -5,7 +5,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index f43dd814aa95..7a945471c770 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -4,7 +4,7 @@ #pragma once #include "core/providers/js/js_kernel.h" -#include "core/common/gsl.h" +#include #include "core/providers/cpu/tensor/transpose.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index d4967b625182..6f76d04a9691 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -9,7 +9,7 @@ #include "core/common/inlined_containers.h" #include "core/graph/basic_types.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksTypes.h" -#include "core/common/gsl.h" +#include // This is the minimal Android API Level required by ORT NNAPI EP to run // ORT running on any host system with Android API level less than this will fall back to CPU EP diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index d0ae32378379..12416ea0c121 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -56,7 +56,13 @@ DEFINE_ADD_OPERAND_FROM_SCALAR(float, FLOAT32); #undef DEFINE_ADD_OPERAND_FROM_SCALAR void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { - skipped_initializers_.insert(tensor_name); + // decrement usage count if this is a known initializer. + // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names + // without first checking if the value is an initializer. + auto entry = initializer_usage_.find(tensor_name); + if (entry != initializer_usage_.end()) { + entry->second -= 1; + } } Status ModelBuilder::Prepare() { @@ -87,7 +93,16 @@ static size_t GetPaddedByteSize(size_t size) { } void ModelBuilder::PreprocessInitializers() { + const auto& initializers = GetInitializerTensors(); + for (const auto& node_unit : node_unit_holder_) { + // find all initializers consumed. AddInitializersToSkip will potentially decrement the usage count. + for (const auto& input : node_unit->Inputs()) { + if (input.node_arg.Exists() && Contains(initializers, input.node_arg.Name())) { + initializer_usage_[input.node_arg.Name()]++; + } + } + if (const auto* op_builder = GetOpBuilder(*node_unit)) { op_builder->AddInitializersToSkip(*this, *node_unit); } @@ -208,11 +223,16 @@ Status ModelBuilder::RegisterInitializers() { std::vector> initializers(initializer_size); size_t sizeAll = 0; + const auto should_skip_initializer = [this](const std::string& name) -> bool { + const auto it = initializer_usage_.find(name); + return it == initializer_usage_.end() || it->second == 0; + }; + int i = 0; for (const auto& pair : initializer_tensors) { const auto& tensor = *pair.second; const auto& name = tensor.name(); - if (Contains(skipped_initializers_, name)) + if (should_skip_initializer(name)) continue; Shape shape; @@ -249,7 +269,7 @@ Status ModelBuilder::RegisterInitializers() { size_t offset = 0; for (const auto& pair : initializer_tensors) { const auto& tensor = *pair.second; - if (Contains(skipped_initializers_, tensor.name())) + if (should_skip_initializer(tensor.name())) continue; auto [index, size, padded_size] = initializers[i++]; @@ -439,10 +459,11 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( Status ModelBuilder::AddOperations() { const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); for (const auto node_idx : node_indices) { - LOGS_DEFAULT(VERBOSE) << "Adding node [" << node_idx << "]"; const auto* node(graph_viewer_.GetNode(node_idx)); const NodeUnit& node_unit = GetNodeUnit(node); + LOGS_DEFAULT(VERBOSE) << "Adding node [" << node_unit.Name() << "] at index [" << node_unit.Index() << "]"; + // Since we may have NodeUnit with multiple nodes, insert NodeUnit with the first occurrence of // its node(s) in topological order may cause the incorrect topological order while inserting // NodeUNits, for example, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index 8eddf389d3b5..b2118150dd30 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -134,7 +134,7 @@ class ModelBuilder { std::unordered_set operands_; std::unordered_set fused_activations_; - std::unordered_set skipped_initializers_; + std::unordered_map initializer_usage_; // All activation nodes (Relu, Relu1, Relu6) as a map std::unordered_map activation_node_units_; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index c1770e0119b2..1c82d5e7452f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -8,7 +8,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/common/logging/logging.h" diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h index d3d6da8364b6..2adf346332c6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h @@ -2098,7 +2098,7 @@ struct NnApi { * @param executionCallback The execution callback to set. * @param callbackContext The context to be passed to the callbacks when they * are invoked. The context object may be used by multiple threads - * simulatenously, so it must be thread-safe. + * simultaneously, so it must be thread-safe. */ void (*SL_ANeuralNetworksDiagnostic_registerCallbacks)( ANeuralNetworksDiagnosticCompilationFinishedCallback compilationCallback, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 1713f201c9d6..5283e9a559cd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -127,7 +127,8 @@ Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, int32_t elem_data_type = 0; ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(input_1.node_arg, elem_data_type)); - const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) || + const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || + (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) || (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16); ORT_RETURN_IF_NOT(is_signed_type, "Conv weights must be of a signed quantized type if quantized per-channel"); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index c8bd31bde77d..f44efb1eba6d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -13,7 +13,7 @@ // #include "GPU/QnnGpuCommon.h" #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" -#include "core/common/gsl.h" +#include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 3a8a8af17b90..f85cdc401a15 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include "qnn_model_wrapper.h" #include "core/common/safeint.h" @@ -313,7 +315,8 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector& zero_points) const { + /*out*/ std::vector& zero_points, + /*out*/ int32_t& onnx_data_type) const { const auto& graph_initializers = GetInitializerTensors(); auto iter = graph_initializers.find(initializer_name); ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for zero-point(s): ", @@ -323,13 +326,14 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name, ORT_RETURN_IF_NOT(zp_tensor_proto->has_data_type(), "Expected zero-point initializer ", initializer_name.c_str(), " to have a proto data type."); - const int32_t onnx_data_type = zp_tensor_proto->data_type(); + onnx_data_type = zp_tensor_proto->data_type(); std::vector initializer_bytes; ORT_RETURN_IF_ERROR(UnpackInitializerData(*zp_tensor_proto, initializer_bytes)); switch (onnx_data_type) { // QNN use -offset for some reason + case ONNX_NAMESPACE::TensorProto_DataType_INT4: // INT4 zero-points are unpacked as 8-bit values for QNN case ONNX_NAMESPACE::TensorProto_DataType_INT8: { auto int8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); std::transform(int8_span.begin(), int8_span.end(), std::back_inserter(zero_points), @@ -338,6 +342,7 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name, }); break; } + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: // UINT4 zero-points are unpacked as 8-bit values for QNN case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { auto uint8_span = ReinterpretAsSpan(gsl::make_span(initializer_bytes)); std::transform(uint8_span.begin(), uint8_span.end(), std::back_inserter(zero_points), @@ -584,10 +589,36 @@ void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector& unpacked_tensor) const { if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) { - return onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), unpacked_tensor); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), + unpacked_tensor)); + } else { + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor)); + } + + int32_t onnx_data_type = initializer.data_type(); + + // If this is an int4, we need to unpack it because QNN treats int4 as a full int8. + if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); + const size_t num_elems = shape.Size(); + std::vector packed_int4_bytes = std::move(unpacked_tensor); + unpacked_tensor = std::vector(num_elems); + + auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); + auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); + ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); + const size_t num_elems = shape.Size(); + std::vector packed_int4_bytes = std::move(unpacked_tensor); + unpacked_tensor = std::vector(num_elems); + + auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); + auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); + ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); } - return onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor); + return Status::OK(); } } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 0705a1d1b8f5..9ab122b7f8e2 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -216,7 +216,9 @@ class QnnModelWrapper { Status UnpackScales(const std::string& initializer_name, std::vector& scales) const; // Unpack zero-points from initializer and convert to int32_t (1 zero-point for per-tensor, > 1 for per-channel). - Status UnpackZeroPoints(const std::string& initializer_name, std::vector& zero_points) const; + Status UnpackZeroPoints(const std::string& initializer_name, + /*out*/ std::vector& zero_points, + /*out*/ int32_t& onnx_data_type) const; // Checks if a tensor in the ONNX graph is per-channel quantized. Status IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc index 401d403c15b0..2d22c3c1b822 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc @@ -9,6 +9,9 @@ #include "QnnTypes.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#define ALIGN_PTR_UP(ptr, align, type) \ + reinterpret_cast((reinterpret_cast(ptr) + (align)-1) & ~((align)-1)) + namespace onnxruntime { namespace qnn { @@ -38,9 +41,10 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const { return QnnQuantParamsWrapper(*this); } +// Initializes by copying from a Qnn_QuantizeParams_t. Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { - if (scale_offset_data_) { - scale_offset_data_.reset(nullptr); + if (per_channel_data_) { + per_channel_data_.reset(nullptr); params_ = QNN_QUANTIZE_PARAMS_INIT; } @@ -51,6 +55,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { switch (params.quantizationEncoding) { case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: + case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: params_ = params; break; case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { @@ -63,15 +68,49 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets; if (num_elems > 0) { - scale_offset_data_ = std::make_unique(num_elems); - gsl::span src_span(params.axisScaleOffsetEncoding.scaleOffset, num_elems); - std::copy(src_span.begin(), src_span.end(), scale_offset_data_.get()); - params_.axisScaleOffsetEncoding.scaleOffset = scale_offset_data_.get(); + const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); + constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); + per_channel_data_ = std::make_unique(num_bytes + align); + Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); + + std::memcpy(aligned_dst, params.axisScaleOffsetEncoding.scaleOffset, num_bytes); + params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst; } else { params_.axisScaleOffsetEncoding.scaleOffset = nullptr; } break; } + case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: { + const uint32_t num_elems = params.bwAxisScaleOffsetEncoding.numElements; + + params_.encodingDefinition = params.encodingDefinition; + params_.quantizationEncoding = params.quantizationEncoding; + params_.bwAxisScaleOffsetEncoding.axis = params.bwAxisScaleOffsetEncoding.axis; + params_.bwAxisScaleOffsetEncoding.bitwidth = params.bwAxisScaleOffsetEncoding.bitwidth; + params_.bwAxisScaleOffsetEncoding.numElements = num_elems; + + // Deep copy the scales[] and offsets[] arrays + if (num_elems > 0) { + const size_t num_scale_bytes = num_elems * sizeof(float); + const size_t num_zp_bytes = num_elems * sizeof(int32_t); + const size_t num_bytes = num_scale_bytes + num_zp_bytes; + constexpr std::uintptr_t align = alignof(float); + static_assert(alignof(float) == alignof(int32_t)); + + per_channel_data_ = std::make_unique(num_bytes + align); + char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); + char* zps_begin = scales_begin + num_scale_bytes; + + std::memcpy(scales_begin, params.bwAxisScaleOffsetEncoding.scales, num_scale_bytes); + std::memcpy(zps_begin, params.bwAxisScaleOffsetEncoding.offsets, num_zp_bytes); + params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast(scales_begin); + params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast(zps_begin); + } else { + params_.bwAxisScaleOffsetEncoding.scales = nullptr; + params_.bwAxisScaleOffsetEncoding.offsets = nullptr; + } + break; + } default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding); } @@ -79,11 +118,13 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) { return Status::OK(); } +// Initialize this object from a (potentially) quantized ONNX tensor. +// QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers. Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) { const std::optional& ort_quant_params = io_def.quant_param; - if (scale_offset_data_) { - scale_offset_data_.reset(nullptr); + if (per_channel_data_) { + per_channel_data_.reset(nullptr); params_ = QNN_QUANTIZE_PARAMS_INIT; } @@ -98,17 +139,25 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales)); + bool is_int4_type = false; + if (ort_quant_params->zero_point != nullptr) { - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points)); + int32_t onnx_tp_type = 0; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points, + onnx_tp_type)); + + is_int4_type = (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || + (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); } const bool is_per_tensor = scales.size() == 1; - if (is_per_tensor) { + // QNN uses different structs to represent quantization parameters depending on + // - per-tensor vs per-channel + // - int4 vs not int4 + if (is_per_tensor && !is_int4_type) { params_.encodingDefinition = QNN_DEFINITION_DEFINED; params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; - - // Parse scale & zero_point params_.scaleOffsetEncoding.scale = scales[0]; if (ort_quant_params->zero_point != nullptr) { @@ -117,8 +166,62 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con } else { params_.scaleOffsetEncoding.offset = 0; } - } else { - // Per-channel quantization. + } else if (is_per_tensor && is_int4_type) { + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET; + params_.bwScaleOffsetEncoding.bitwidth = 4; + params_.bwScaleOffsetEncoding.scale = scales[0]; + + if (ort_quant_params->zero_point != nullptr) { + ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value"); + params_.bwScaleOffsetEncoding.offset = zero_points[0]; + } else { + params_.bwScaleOffsetEncoding.offset = 0; + } + } else if (!is_per_tensor && is_int4_type) { + const auto* io_shape = io_def.node_arg.Shape(); + ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); + const int32_t io_rank = io_shape->dim_size(); + + constexpr int64_t DEFAULT_QDQ_AXIS = 1; + int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS); + if (axis < 0) { + axis += io_rank; + } + ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank, + "Quantization axis must be within the range [0, rank - 1]"); + + const size_t num_elems = scales.size(); + const bool no_zero_points = zero_points.empty(); + ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value"); + ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, + "Expected the same number of zero-points and scales for per-channel quantization"); + + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; + params_.bwAxisScaleOffsetEncoding.axis = static_cast(*(ort_quant_params->axis)); + params_.bwAxisScaleOffsetEncoding.bitwidth = 4; + params_.bwAxisScaleOffsetEncoding.numElements = static_cast(num_elems); + + const size_t num_scale_bytes = num_elems * sizeof(float); + const size_t num_zp_bytes = num_elems * sizeof(int32_t); + const size_t num_bytes = num_scale_bytes + num_zp_bytes; + constexpr std::uintptr_t align = alignof(float); + per_channel_data_ = std::make_unique(num_bytes + align); + + char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); + char* zps_begin = scales_begin + num_scale_bytes; + gsl::span scales_span(reinterpret_cast(scales_begin), num_elems); + gsl::span zps_span(reinterpret_cast(zps_begin), num_elems); + + for (size_t i = 0; i < num_elems; i++) { + scales_span[i] = scales[i]; + zps_span[i] = no_zero_points ? 0 : zero_points[i]; + } + + params_.bwAxisScaleOffsetEncoding.scales = scales_span.data(); + params_.bwAxisScaleOffsetEncoding.offsets = zps_span.data(); + } else if (!is_per_tensor && !is_int4_type) { const auto* io_shape = io_def.node_arg.Shape(); ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape"); const int32_t io_rank = io_shape->dim_size(); @@ -140,8 +243,11 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems, "Expected the same number of zero-points and scales for per-channel quantization"); - scale_offset_data_ = std::make_unique(num_elems); - gsl::span data_span(scale_offset_data_.get(), num_elems); + const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); + constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); + per_channel_data_ = std::make_unique(num_bytes + align); + Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); + gsl::span data_span(aligned_dst, num_elems); for (size_t i = 0; i < num_elems; i++) { data_span[i].scale = scales[i]; @@ -151,6 +257,8 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con params_.axisScaleOffsetEncoding.axis = static_cast(axis); params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast(num_elems); params_.axisScaleOffsetEncoding.scaleOffset = data_span.data(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected tensor kind for QuantParamsWrapper::Init()"); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h index 3cf04da97a8f..d1f93e5a692b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h @@ -5,7 +5,7 @@ #include #include "QnnTypes.h" #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "core/framework/node_unit.h" namespace onnxruntime { @@ -48,17 +48,17 @@ class QnnQuantParamsWrapper { (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET)); } - bool IsPerChannel(bool include_bw = false) const { + bool IsPerChannel() const { return params_.encodingDefinition == QNN_DEFINITION_DEFINED && (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET || - (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); } // Handle transposing of a per-channel quantized tensor. The quantization parameter's axis // must be transposed using the inverse permutation of the Transpose. template Status HandleTranspose(gsl::span perm) { - if (!IsPerChannel(true)) { + if (!IsPerChannel()) { return Status::OK(); } @@ -82,7 +82,7 @@ class QnnQuantParamsWrapper { template Status HandleUnsqueeze(gsl::span orig_shape, gsl::span new_shape) { - if (!IsPerChannel(true)) { + if (!IsPerChannel()) { return Status::OK(); } @@ -134,7 +134,13 @@ class QnnQuantParamsWrapper { private: Qnn_QuantizeParams_t params_; - std::unique_ptr scale_offset_data_; // Stores per-channel scales and offsets + + // Stores arrays of per-channel scales and offsets. Fields in params_ point to this data. + // + // Use an opaque array of bytes because QNN uses different data layouts depending on the quantization encoding: + // - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: array of scale/zp pairs [{scale0, zp0}, {scale1, zp1}, ...] + // - QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: parallel arrays for scales and zps [scale0, ...] [zp0, zp1, ...] + std::unique_ptr per_channel_data_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 19362daee6ca..c2e500b8980a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -43,6 +43,8 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) { } size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { const static std::unordered_map elem_type_to_size = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, sizeof(Int4x2)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, sizeof(UInt4x2)}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, @@ -162,6 +164,12 @@ std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) { case QNN_DATATYPE_BOOL_8: out << "QNN_DATATYPE_BOOL_8"; break; + case QNN_DATATYPE_SFIXED_POINT_4: + out << "QNN_DATATYPE_SFIXED_POINT_4"; + break; + case QNN_DATATYPE_UFIXED_POINT_4: + out << "QNN_DATATYPE_UFIXED_POINT_4"; + break; default: ORT_THROW("Unknown Qnn Data type"); } @@ -216,6 +224,10 @@ std::ostream& operator<<(std::ostream& out, const Qnn_QuantizeParams_t& quantize if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { out << " scale=" << quantize_params.scaleOffsetEncoding.scale; out << " offset=" << quantize_params.scaleOffsetEncoding.offset; + } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { + out << " bitwidth=" << quantize_params.bwScaleOffsetEncoding.bitwidth; + out << " scale=" << quantize_params.bwScaleOffsetEncoding.scale; + out << " offset=" << quantize_params.bwScaleOffsetEncoding.offset; } else if (quantize_params.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { out << " axis=" << quantize_params.axisScaleOffsetEncoding.axis; size_t num_elems = quantize_params.axisScaleOffsetEncoding.numScaleOffsets; @@ -292,7 +304,9 @@ std::ostream& operator<<(std::ostream& out, const Qnn_ClientBuffer_t& client_buf T* data = reinterpret_cast(client_bufer.data); out << " dataSize=" << client_bufer.dataSize; uint32_t count = client_bufer.dataSize / sizeof(T); - count = count > 100 ? 100 : count; // limit to 100 data + const bool truncate = count > 100; + + count = truncate ? 100 : count; // limit to 100 data out << " clientBuf=("; for (uint32_t i = 0; i < count; i++) { if constexpr (sizeof(T) == 1) { @@ -301,7 +315,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_ClientBuffer_t& client_buf out << data[i] << " "; } } - out << ")"; + out << (truncate ? "..." : "") << ")"; return out; } @@ -319,6 +333,8 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor) { } out << ")"; out << " memType=" << GetQnnTensorMemType(tensor); +// TODO: the code below has compilation errors with the latest ABSL +#if 0 if (GetQnnTensorMemType(tensor) == QNN_TENSORMEMTYPE_RAW) { if (GetQnnTensorDataType(tensor) == QNN_DATATYPE_FLOAT_32) { operator<< (out, GetQnnTensorClientBuf(tensor)); @@ -341,6 +357,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor) { operator<< (out, GetQnnTensorClientBuf(tensor)); } } +#endif out << " quantizeParams:" << GetQnnTensorQParams(tensor); return out; } @@ -432,10 +449,12 @@ bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn }; const std::unordered_map onnx_to_qnn_data_type_quantized = { + {ONNX_NAMESPACE::TensorProto_DataType_INT4, QNN_DATATYPE_SFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_INT8, QNN_DATATYPE_SFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_INT16, QNN_DATATYPE_SFIXED_POINT_16}, {ONNX_NAMESPACE::TensorProto_DataType_INT32, QNN_DATATYPE_SFIXED_POINT_32}, {ONNX_NAMESPACE::TensorProto_DataType_INT64, QNN_DATATYPE_INT_64}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT4, QNN_DATATYPE_UFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_UINT8, QNN_DATATYPE_UFIXED_POINT_8}, {ONNX_NAMESPACE::TensorProto_DataType_UINT16, QNN_DATATYPE_UFIXED_POINT_16}, {ONNX_NAMESPACE::TensorProto_DataType_UINT32, QNN_DATATYPE_UFIXED_POINT_32}, diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h index 49bfddad36b4..57c1fc506807 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.h +++ b/onnxruntime/core/providers/rocm/math/softmax.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/providers/rocm/rocm_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/rocm/miopen_common.cc b/onnxruntime/core/providers/rocm/miopen_common.cc index 6b01f02ae49b..6b08d392069a 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.cc +++ b/onnxruntime/core/providers/rocm/miopen_common.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "miopen_common.h" -#include "core/common/gsl.h" +#include #include "core/providers/cpu/tensor/utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index 6214ec7bc0ea..a2b587a56466 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -49,10 +49,10 @@ miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, - const miopenConvFwdAlgorithm_t* algo, int n_algo) { + const miopenConvFwdAlgorithm_t* algo, int n_algo, int device_id = 0) { // TODO: get maximum available size from memory arena size_t free, total; - HIP_CALL_THROW(hipMemGetInfo(&free, &total)); + onnxruntime::rocm::hipMemGetInfoAlt(device_id, &free, &total); // Assuming 10% of fragmentation free = static_cast(static_cast(free) * 0.9); size_t max_ws_size = 0; @@ -283,7 +283,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) int algo_count = 1; const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos) + size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId()) : AlgoSearchWorkspaceSize; IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index 7974053c3249..ca12720fb3eb 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -143,6 +143,8 @@ template Status RocmCall(hiprandStatus_t retCode, const template void RocmCall(hiprandStatus_t retCode, const char* exprString, const char* libName, hiprandStatus_t successCode, const char* msg, const char* file, const int line); template Status RocmCall(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line); template void RocmCall(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line); +template Status RocmCall(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line); +template void RocmCall(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line); #ifdef ORT_USE_NCCL template Status RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); diff --git a/onnxruntime/core/providers/rocm/rocm_common.h b/onnxruntime/core/providers/rocm/rocm_common.h index 07b3e252c600..a8ddb8523303 100644 --- a/onnxruntime/core/providers/rocm/rocm_common.h +++ b/onnxruntime/core/providers/rocm/rocm_common.h @@ -10,7 +10,7 @@ #include "core/providers/rocm/shared_inc/rocm_call.h" #include "core/providers/rocm/shared_inc/fast_divmod.h" #include "core/util/math.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace rocm { @@ -67,5 +67,17 @@ inline int warpSizeDynamic() { return deviceProp.warpSize; } +inline void hipMemGetInfoAlt(uint32_t deviceId, size_t* pFree, size_t* pTotal) { + const auto status = hipMemGetInfo(pFree, pTotal); + if (status != hipSuccess) { + size_t usedMemory = 0; + ROCMSMI_CALL_THROW(rsmi_init(0)); + ROCMSMI_CALL_THROW(rsmi_dev_memory_total_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, pTotal)); + ROCMSMI_CALL_THROW(rsmi_dev_memory_usage_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, &usedMemory)); + *pFree = *pTotal - usedMemory; + ROCMSMI_CALL_THROW(rsmi_shut_down()); + } +} + } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 76964e1aed93..c1cedd47501a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -43,7 +43,8 @@ class Memcpy final : public OpKernel { // do we support async copy? // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, // so we don't need the check here. - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, + Y->Location().device); ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); } else { @@ -88,10 +89,12 @@ class Memcpy final : public OpKernel { Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), + alloc); auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, + *ctx->GetComputeStream())); Y->Add(std::move(*target_tensor)); } return Status::OK(); @@ -127,7 +130,8 @@ ONNX_OPERATOR_KERNEL_EX( AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, + ROCMExecutionProviderExternalAllocatorInfo + external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( @@ -149,7 +153,8 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi device_id, true, {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, + : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, // make it stream aware true, // enable cross stream sharing? @@ -160,8 +165,11 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi } } -ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t /*gpu_mem_limit*/, - ArenaExtendStrategy /*arena_extend_strategy*/, ROCMExecutionProviderExternalAllocatorInfo /*external_allocator_info*/, +ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, + size_t /*gpu_mem_limit*/, + ArenaExtendStrategy /*arena_extend_strategy*/, + ROCMExecutionProviderExternalAllocatorInfo + /*external_allocator_info*/, OrtArenaCfg* /*default_memory_arena_cfg*/) { HIP_CALL_THROW(hipSetDevice(device_id)); @@ -229,7 +237,8 @@ void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { } ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, + : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { HIP_CALL_THROW(hipSetDevice(info_.device_id)); @@ -261,7 +270,7 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in size_t free = 0; size_t total = 0; - HIP_CALL_THROW(hipMemGetInfo(&free, &total)); + onnxruntime::rocm::hipMemGetInfoAlt(info_.device_id, &free, &total); OverrideTunableOpInfoByEnv(info_); @@ -313,7 +322,8 @@ ROCMExecutionProvider::PerThreadContext& ROCMExecutionProvider::GetPerThreadCont // get or create a context if (context_state_.retired_context_pool.empty()) { context = std::make_shared(info_.device_id, stream_, info_.gpu_mem_limit, - info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg); + info_.arena_extend_strategy, info_.external_allocator_info, + info_.default_memory_arena_cfg); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -357,7 +367,8 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured(0)) { + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && + !GetPerThreadContext().IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; GetPerThreadContext().CaptureBegin(0); } @@ -471,7 +482,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, + LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); @@ -504,20 +516,32 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, + LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, + LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, + LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, + LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, + LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Add); @@ -573,7 +597,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Reciprocal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, + Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sqrt); @@ -587,12 +612,18 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Erf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, bool, Not); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, LRN); @@ -600,11 +631,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, + ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, + ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, + AveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool); diff --git a/onnxruntime/core/providers/rocm/rocm_pch.h b/onnxruntime/core/providers/rocm/rocm_pch.h index ccd17157ffb4..723b990c8d29 100644 --- a/onnxruntime/core/providers/rocm/rocm_pch.h +++ b/onnxruntime/core/providers/rocm/rocm_pch.h @@ -14,6 +14,7 @@ #include #include #include +#include #ifdef ORT_USE_NCCL #include diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 88ef666678b3..a739fe0a5d19 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -5,7 +5,7 @@ #include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/rocm/rocm_provider_factory_creator.h" -#include "core/common/gsl.h" +#include #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_execution_provider_info.h" diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index b6b40666b8bd..253ded1911cb 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -17,6 +17,7 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) +#define ROCMSMI_CALL(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPSPARSE_CALL(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) @@ -27,6 +28,7 @@ std::conditional_t RocmCall( #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) +#define ROCMSMI_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPSPARSE_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 7cdfb0ffc19f..590bddabdba5 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -12,7 +12,7 @@ #include #include #include -#include "core/common/gsl.h" +#include #include #include #include diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index be924d6a6826..8a601c156bd0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -18,7 +18,7 @@ #include "core/providers/cuda/gpu_data_transfer.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" -#include "core/common/gsl.h" +#include #include #include #include @@ -287,6 +287,7 @@ void CudaCall(cudaError retCode, const char* exprString, const return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } +#ifndef USE_CUDA_MINIMAL template <> Status CudaCall(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line) { return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line); @@ -306,6 +307,7 @@ template <> void CudaCall(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line) { return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } +#endif #if NV_TENSORRT_MAJOR >= 10 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, @@ -1119,20 +1121,26 @@ Status BindKernelOutput(Ort::KernelContext& ctx, TensorrtExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { if (has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); +#ifndef USE_CUDA_MINIMAL ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream))); +#else + (void)(stream); +#endif } } TensorrtExecutionProvider::PerThreadContext::~PerThreadContext() { +#ifndef USE_CUDA_MINIMAL if (external_cublas_handle_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); } if (external_cudnn_handle_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); } +#endif trt_context_map_.clear(); } @@ -1268,10 +1276,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); +#ifndef USE_CUDA_MINIMAL ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_))); +#endif } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1442,6 +1452,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (!ep_context_embed_mode_env.empty()) { ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); } + // incase the EP context is dumped the engine cache has to be enabled + if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + engine_cache_enable_ = true; + } enable_engine_cache_for_ep_context_model(); @@ -1737,8 +1751,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } if (external_stream_) { +#ifndef USE_CUDA_MINIMAL ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); +#endif } if (!external_stream_ && stream_) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index ec140579569b..b58e86237860 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -3,9 +3,13 @@ #pragma once #include +#ifndef USE_CUDA_MINIMAL #include -#include - +#else +typedef void* cudnnHandle_t; +typedef void* cublasHandle_t; +typedef void* cudnnStatus_t; +#endif #include "core/providers/tensorrt/nv_includes.h" #include "core/platform/ort_mutex.h" diff --git a/onnxruntime/core/providers/tvm/tvm_api.cc b/onnxruntime/core/providers/tvm/tvm_api.cc index 4c46ea5ffae7..e9a7d002e77c 100644 --- a/onnxruntime/core/providers/tvm/tvm_api.cc +++ b/onnxruntime/core/providers/tvm/tvm_api.cc @@ -16,7 +16,7 @@ #include #include "core/common/common.h" -#include "core/common/gsl.h" +#include #include "tvm_api.h" diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index f4d56dd618a8..bb2883512037 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -3,7 +3,7 @@ #pragma once #include #include "vaip/my_ort.h" -#include "core/common/gsl.h" +#include namespace vaip { diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 1133751d82d6..e9ae93ded40c 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -123,11 +123,11 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config vaip_core::DllSafe>> compile_onnx_model( const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { #ifndef _WIN32 - auto model_path = graph_viewer.ModelPath().ToPathString(); + auto model_path = graph_viewer.ModelPath().string(); #else using convert_t = std::codecvt_utf8; std::wstring_convert strconverter; - auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); + auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().string()); #endif if (s_library_vitisaiep.compile_onnx_model_with_options) { return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); @@ -218,7 +218,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { auto& logger = logging::LoggingManager::DefaultLogger(); auto& model = const_cast(const_model); auto model_proto = model.ToProto(); - auto file_path = model.MainGraph().ModelPath().ToPathString(); + auto file_path = model.MainGraph().ModelPath().string(); auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()}; auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger); auto status = ret->MainGraph().Resolve(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h index f6831604d883..def522b8a3a0 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_gsl.h @@ -3,7 +3,7 @@ #pragma once #ifdef ONNXRUNTIME_VITISAI_EP_STUB -#include "core/common/gsl.h" +#include #else #include #endif diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h index 9a59d90365f6..7285e4d02c8e 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc index 894bf8e4444f..653940fa5e61 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc @@ -100,6 +100,16 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial } } for (const auto& output : node_unit.Outputs()) { + for (const auto& dim : output.node_arg.Shape()->dim()) { + if (!dim.has_dim_value()) { + LOGS_DEFAULT(WARNING) << "Dynamic shape is not supported for now, for output:" << output.node_arg.Name(); + return false; + } + if (dim.dim_value() == 0 && output.node_arg.Shape()->dim_size() > 1) { + LOGS_DEFAULT(WARNING) << "Zero in shape is not supported for now, for output:" << output.node_arg.Name(); + return false; + } + } if (output.quant_param.has_value()) { if (!has_supported_shape(output.quant_param->scale, node_unit.Name(), node_unit.OpType())) return false; diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h index c0cf3365f46e..692dbc833f0a 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h @@ -40,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder { bool IsSupported(const onnxruntime::GraphViewer& graph_viewer, const NodeUnit& node_unit) const override; bool BuildOp(vsi::npu::GraphEP* graph_ep, - const onnxruntime::GraphViewer& graph_viewer, const NodeUnit& node_unit); + const onnxruntime::GraphViewer& graph_viewer, const NodeUnit& node_unit) override; virtual bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const { return true; diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h index 6579f0ca9045..57b080acdce4 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h index 368cb092657c..9e5570645126 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include "core/providers/vsinpu/builders/impl/base_op_builder.h" diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h index 4d3fc658b7be..87c8c70247be 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h index d44e1ce1799c..3ed432c2efa1 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h index 50b295f2fb53..5551ff2a0d31 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/dropout_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/dropout_op_builder.h new file mode 100644 index 000000000000..bd8fa0bc587c --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/dropout_op_builder.h @@ -0,0 +1,81 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class DropoutOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + if (node_unit.Inputs().size() > 2) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = + initializers.at(node_unit.Inputs()[2].node_arg.Name()); + std::vector training_mode(1); + auto status = onnxruntime::utils::UnpackTensor( + *tensor_proto, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().data() : nullptr, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().size() : 0, + training_mode.data(), training_mode.size()); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Failed to get data training mode tensor."; + return false; + } + if (training_mode[0] == true) { + LOGS_DEFAULT(WARNING) << "Only support inference typed dropout now."; + return false; + } + } + if (node_unit.Inputs().size() > 1) return false; + return true; + } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + if (helper.HasAttr("seed")) { + LOGS_DEFAULT(WARNING) << "Not support seed in Dropout op."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating DropOut Op."; + auto op = graph_ep->GetGraph()->CreateOperation(1.0); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h index 89809a451334..df2e429f58b2 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h @@ -22,6 +22,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include @@ -29,28 +30,28 @@ namespace onnxruntime { namespace vsi { namespace npu { -#define ELEMENTWISE_OP_BUILDER(onnx_op_type, vsinpu_op_kind) \ - class onnx_op_type##OpBuilder : public BaseOpBuilder { \ - bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, \ - const Node* node) const override { \ - for (auto input : node->InputDefs()) { \ - if (*input->Type() == "tensor(int64)") { \ - LOGS_DEFAULT(WARNING) << "Int64 type is not suppoted as elementwise operation input."; \ - return false; \ - } \ - } \ - return true; \ - } \ - bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, \ - std::vector>& inputs, \ - std::vector>& outputs, \ - const NodeUnit& node_unit) override { \ - LOGS_DEFAULT(INFO) << "Creating " << #onnx_op_type << " Op"; \ - auto op = graph_ep->GetGraph() -> CreateOperation(); \ - (*op).BindInputs(inputs).BindOutputs(outputs); \ - return true; \ - ; \ - } \ +#define ELEMENTWISE_OP_BUILDER(onnx_op_type, vsinpu_op_kind) \ + class onnx_op_type##OpBuilder : public BaseOpBuilder { \ + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, \ + const Node* node) const override { \ + for (auto input : node->InputDefs()) { \ + if (*input->Type() == "tensor(int64)") { \ + LOGS_DEFAULT(WARNING) << "Int64 type is not supported as elementwise operation input."; \ + return false; \ + } \ + } \ + return true; \ + } \ + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, \ + std::vector>& inputs, \ + std::vector>& outputs, \ + const NodeUnit& node_unit) override { \ + LOGS_DEFAULT(INFO) << "Creating " << #onnx_op_type << " Op"; \ + auto op = graph_ep->GetGraph() -> CreateOperation(); \ + (*op).BindInputs(inputs).BindOutputs(outputs); \ + return true; \ + ; \ + } \ }; ELEMENTWISE_OP_BUILDER(Add, Add); diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h index dfb0bb9c1b99..fa8ccfc96a73 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h index 0325b68ae0ad..bd91bbd81e1f 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h index 6f2c590b862b..07388f862ae1 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h index 8cdf72906b64..f439736bde09 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h index 997163c6b1a6..085850e8a8b5 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h index 7cfa9faf6848..34339875cd6e 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h index def37b1ec101..0c09d2183420 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h index dc51e99730c1..4789c0685592 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h index 8b63a07e17f1..a5ad42d3c41f 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h index 7447c8b6b0b9..8b53d2d64296 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h index 63ae491909bd..6c0fd44464e1 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h index 3b0a282c5de8..d2e16026b712 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h index 8857efe3537e..db3ccb153ec3 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include @@ -136,8 +137,10 @@ class ResizeOpBuilder : public BaseOpBuilder { for (int i = 0; i < input_shape.size(); i++) { out_shape[i] = input_shape[i] * scales[input_shape.size() - 1 - i]; } + target_h = static_cast(out_shape[1]); + target_w = static_cast(out_shape[0]); op = graph_ep->GetGraph()->CreateOperation(resize_type, 0, align_corners, - half_pixel_center, out_shape[1], out_shape[0]); + half_pixel_center, target_h, target_w); } } diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/slice_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/slice_op_builder.h new file mode 100644 index 000000000000..8457f91e02cb --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/slice_op_builder.h @@ -0,0 +1,148 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +enum SliceInputs { + data = 0, + starts = 1, + ends = 2, + axes = 3, + steps = 4 +}; + +class SliceOpBuilder : public BaseOpBuilder { + public: + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 10; } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (!util::IsTypeSupported(&iodef.node_arg) || + (i == 0 && *iodef.node_arg.Type() == "tensor(int64)") || + (i != 0 && !Contains(initializers, iodef.node_arg.Name()))) { + return false; + } + } + return true; + } + + template + void CopyTensorDataToVector(const std::shared_ptr& tensor, std::vector& vec) { + std::vector data(tensor->GetSpec().GetElementNum()); + tensor->CopyDataFromTensor(data.data()); + std::transform(data.begin(), data.end(), vec.begin(), [](T val) { + return static_cast(std::clamp(val, static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()))); + }); + } + + void ProcessAxes(const std::vector>& inputs, + int dims, bool full_axes, + std::vector& timvx_starts, + std::vector& timvx_ends, + std::vector& timvx_strides) { + auto num_elements = full_axes ? dims : inputs[SliceInputs::axes]->GetSpec().GetElementNum(); + std::vector onnx_starts(num_elements), onnx_ends(num_elements), + onnx_axes(num_elements), onnx_strides(num_elements, 1); + + auto data_type = inputs[SliceInputs::starts]->GetSpec().GetDataType(); + std::iota(onnx_axes.begin(), onnx_axes.end(), 0); + if (data_type == tim::vx::DataType::INT64) { + CopyTensorDataToVector(inputs[SliceInputs::starts], onnx_starts); + CopyTensorDataToVector(inputs[SliceInputs::ends], onnx_ends); + if (inputs.size() > 3) { + CopyTensorDataToVector(inputs[SliceInputs::axes], onnx_axes); + if (inputs.size() == 5) { + CopyTensorDataToVector(inputs[SliceInputs::steps], onnx_strides); + } + } + } else { + CopyTensorDataToVector(inputs[SliceInputs::starts], onnx_starts); + CopyTensorDataToVector(inputs[SliceInputs::ends], onnx_ends); + if (inputs.size() > 3) { + CopyTensorDataToVector(inputs[SliceInputs::axes], onnx_axes); + if (inputs.size() == 5) { + CopyTensorDataToVector(inputs[SliceInputs::steps], onnx_strides); + } + } + } + + if (!full_axes) { + for (auto& axis : onnx_axes) { + axis = HandleNegativeAxis(axis, inputs[0]->GetShape().size()); + } + } + + for (int i = 0; i < dims; ++i) { + if (full_axes || std::find(onnx_axes.begin(), onnx_axes.end(), i) != onnx_axes.end()) { + int axes_index = std::distance(onnx_axes.begin(), std::find(onnx_axes.begin(), onnx_axes.end(), i)); + timvx_starts[i] = onnx_starts[axes_index]; + timvx_ends[i] = onnx_ends[axes_index]; + if (inputs.size() == 5) { + timvx_strides[i] = onnx_strides[axes_index]; + } + } else if (!full_axes) { + timvx_starts[i] = 0; + timvx_ends[i] = inputs[SliceInputs::data]->GetShape()[dims - i - 1]; + } + } + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Slice Op."; + auto total_dims = inputs[SliceInputs::data]->GetShape().size(); + bool full_axes = inputs.size() <= 3 || (inputs[SliceInputs::axes]->GetSpec().GetElementNum() == total_dims); + std::vector timvx_starts(total_dims), timvx_ends(total_dims), timvx_strides(total_dims, 1); + + ProcessAxes(inputs, total_dims, full_axes, timvx_starts, timvx_ends, timvx_strides); + + std::reverse(timvx_starts.begin(), timvx_starts.end()); + std::reverse(timvx_ends.begin(), timvx_ends.end()); + std::reverse(timvx_strides.begin(), timvx_strides.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + timvx_starts, timvx_ends, timvx_strides, 0, 0, 0); + op->BindInput(inputs[SliceInputs::data]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h index dad10c1a0251..055b3680db80 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include @@ -67,7 +68,8 @@ class SoftmaxOpBuilder : public BaseOpBuilder { auto reshaped_spec = inputs[0]->GetSpec().AsTransientSpec().SetShape( std::vector{first_dim, last_dim}); auto reshaped_input = graph_ep->GetGraph()->CreateTensor(reshaped_spec); - auto reshaped_output = graph_ep->GetGraph()->CreateTensor(inputs[0]->GetSpec().AsTransientSpec()); + auto reshaped_output = graph_ep->GetGraph()->CreateTensor( + inputs[0]->GetSpec().AsTransientSpec()); auto reshape_input_op = graph_ep->GetGraph()->CreateOperation( std::vector{first_dim, last_dim}); diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h index 2e1837384618..bc90ae1ff26a 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h index 427457b521b6..9a53bef2a669 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h index d42624c31557..501468251cea 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h index c49c93008b25..d7fbaeaa72fc 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h @@ -21,6 +21,7 @@ * DEALINGS IN THE SOFTWARE. * *****************************************************************************/ +#pragma once #include #include #include diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h index 3a9190d8cb03..27c148c1672c 100644 --- a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h @@ -51,6 +51,8 @@ #include "impl/unsqueeze_op_builder.h" #include "impl/resize_op_builder.h" #include "impl/cast_op_builder.h" +#include "impl/dropout_op_builder.h" +#include "impl/slice_op_builder.h" namespace onnxruntime { namespace vsi { namespace npu { @@ -108,7 +110,8 @@ static const std::map reg = { REGISTER_OP_BUILDER("Unsqueeze", UnsqueezeOpBuilder), REGISTER_OP_BUILDER("Resize", ResizeOpBuilder), REGISTER_OP_BUILDER("Cast", CastOpBuilder), - + REGISTER_OP_BUILDER("Dropout", DropoutOpBuilder), + REGISTER_OP_BUILDER("Slice", SliceOpBuilder) #undef REGISTER_OP_BUILDER }; diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch index a9d02765cf34..2176ff559c68 100644 --- a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -1,34 +1,35 @@ diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake -index e0ccc504d7..6c5aa6ea53 100644 +index 304aa77f54..5c22b7097b 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake -@@ -335,7 +335,7 @@ else() - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp +@@ -354,7 +354,7 @@ else() ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") - if (NOT APPLE) + if (NOT APPLE AND NOT onnxruntime_USE_VSINPU) set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h -index fd6b3df934..f81f1c42b6 100644 +index cdfd283899..678a055b24 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h -@@ -79,6 +79,7 @@ Abstract: +@@ -82,6 +82,9 @@ Abstract: #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) +#if !defined(USE_VSINPU) ++// Had to tempory disable fp16 under VeriSilicon ARM64 to avoid ++// conflict of compilation flag. #if !defined(__APPLE__) // Had to temporary disable fp16 under APPLE ARM64, as compiling // the source files require a hardware specific compilation flag. -@@ -87,7 +88,8 @@ Abstract: +@@ -90,6 +93,7 @@ Abstract: #define MLAS_F16VEC_INTRINSICS_SUPPORTED --#endif // -+#endif +#endif // + #endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc index e51b0713ea41..bbf8255ac294 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -113,7 +113,9 @@ void GraphEP::UpdateTensorMap(const std::string& name, const std::shared_ptr GraphEP::ConstructNodeIO(const std::shared_ptr& op, std::vector input_arg, std::vector output_arg) { +std::shared_ptr GraphEP::ConstructNodeIO(const std::shared_ptr& op, + std::vector input_arg, + std::vector output_arg) { auto info = std::make_shared(); info->op_ = op; std::vector input_names, output_names; @@ -173,7 +175,6 @@ std::shared_ptr GraphEP::MapTIMVXTensor( const auto& arg = nudef.node_arg; if (tensors_.end() != tensors_.find(nudef.node_arg.Name())) { - // if (!quant_param.has_value() || quant_param.has_value() && tensors_[arg.Name()]->GetSpec().GetQuantization().Type() != tim::vx::QuantType::NONE) return tensors_.find(arg.Name())->second; } auto shape = vsi::npu::util::OnnxShapeToTIMVXShape(vsi::npu::util::GetTensorShape(arg)); @@ -190,16 +191,18 @@ std::shared_ptr GraphEP::MapTIMVXTensor( std::optional> scales; std::optional> zps; if (nudef.quant_param.has_value()) { - util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, nudef, node_unit.ModelPath(), scale, zp, scales, zps); } else { auto target_nodeunit = all_quantized_op_inputs_[arg.Name()][0]; auto qinput = all_quantized_op_inputs_[arg.Name()][0]->Inputs(); - auto it = std::find_if(qinput.begin(), qinput.end(), [&arg](const NodeUnitIODef& nud) { return nud.node_arg.Name() == arg.Name(); }); + auto it = std::find_if(qinput.begin(), qinput.end(), [&arg](const NodeUnitIODef& nud) { + return nud.node_arg.Name() == arg.Name(); + }); bool is_conv_bias = std::distance(qinput.begin(), it) == 2; if (!is_conv_bias || it->quant_param.has_value()) { - util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, *it, target_nodeunit->ModelPath(), scale, zp, scales, zps); } else if (!it->quant_param.has_value()) { @@ -209,11 +212,12 @@ std::shared_ptr GraphEP::MapTIMVXTensor( std::optional> in_zps, w_zps; // onnx defines conv bias with non quantization, but it must be quantized in VSINPU support - // The bias scale is set as input_scale * weight_scale if per layer quantized, input_scale* weight_scale[i] if per channel quantized - util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + // The bias scale is set as input_scale * weight_scale if per layer quantized, + // otherwise input_scale* weight_scale[i] if per channel quantized + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, qinput[0], target_nodeunit->ModelPath(), in_scale, in_zp, in_scales, in_zps); - util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + util::GetQuantizationScaleAndZeroPoint(graph_viewer_, qinput[1], target_nodeunit->ModelPath(), w_scale, w_zp, w_scales, w_zps); scale = in_scale * w_scale; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h index bd0f377b820b..49344770d060 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -82,7 +82,8 @@ class GraphEP { void UpdateTensorMap(const std::string& name, const std::shared_ptr& dst_tensor); - std::shared_ptr ConstructNodeIO(const std::shared_ptr& op, std::vector input_arg, std::vector output_arg); + std::shared_ptr ConstructNodeIO(const std::shared_ptr& op, + std::vector input_arg, std::vector output_arg); bool BindTensors(const std::shared_ptr& nodeio_info); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 7444dcfec09a..466fe1f82461 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -137,9 +137,10 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie std::for_each(result.begin(), result.end(), [&graph_viewer](auto& capability) { if (capability && capability->sub_graph && capability->sub_graph->GetMetaDef()) { const auto* meta_def = capability->sub_graph->GetMetaDef(); - bool has_any_non_constant_inputs = std::any_of(meta_def->inputs.begin(), meta_def->inputs.end(), [&graph_viewer](const auto& input) { - return !graph_viewer.IsConstantInitializer(input, true); - }); + bool has_any_non_constant_inputs = std::any_of(meta_def->inputs.begin(), + meta_def->inputs.end(), [&graph_viewer](const auto& input) { + return !graph_viewer.IsConstantInitializer(input, true); + }); // ALL inputs are constant if (!has_any_non_constant_inputs) { @@ -184,7 +185,8 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, const auto tensor_info = onnx_input_tensor.GetTensorTypeAndShapeInfo(); auto origin_tensor = graph_ep->GetGraphInputs()[i]->tensor; - origin_tensor->CopyDataToTensor(onnx_input_tensor.GetTensorRawData(), vsi::npu::util::GetTensorBytes(tensor_info)); + origin_tensor->CopyDataToTensor(onnx_input_tensor.GetTensorRawData(), + vsi::npu::util::GetTensorBytes(tensor_info)); j++; } } diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc index 8008ec1f436a..5d2f701ceac2 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc @@ -412,7 +412,7 @@ bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit) { } void GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::filesystem::path& model_path, float& scale, int32_t& zero_point, std::optional>& pcq_scales, std::optional>& pcq_zps) { scale = 0.0f; @@ -421,7 +421,11 @@ void GetQuantizationScaleAndZeroPoint( const auto& quant_param = *io_def.quant_param; { // get the scale const auto& name = quant_param.scale.Name(); - Initializer unpacked_tensor(*initializers.at(name), model_path); + const auto* s = graph_viewer.GetConstantInitializer(name); + if (!s) { + LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; + }; + Initializer unpacked_tensor(*s, model_path); scale = unpacked_tensor.DataAsSpan()[0]; // per channel quantized handling @@ -434,12 +438,18 @@ void GetQuantizationScaleAndZeroPoint( if (quant_param.zero_point) { // get the zero point if it exists const auto& name = quant_param.zero_point->Name(); - Initializer unpacked_tensor(*initializers.at(name), model_path); + const auto* s = graph_viewer.GetConstantInitializer(name); + if (!s) { + LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; + }; + Initializer unpacked_tensor(*s, model_path); bool is_i8_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT8; // some qdq conv bias is int32 quantized bool is_int32_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT32; - zero_point = is_i8_zp ? static_cast(unpacked_tensor.DataAsSpan()[0]) : is_int32_zp ? static_cast(unpacked_tensor.DataAsSpan()[0]) - : static_cast(unpacked_tensor.DataAsByteSpan()[0]); + zero_point = is_i8_zp + ? static_cast(unpacked_tensor.DataAsSpan()[0]) + : is_int32_zp ? static_cast(unpacked_tensor.DataAsSpan()[0]) + : static_cast(unpacked_tensor.DataAsByteSpan()[0]); // per channel quantized handling if (!unpacked_tensor.dims().empty() && unpacked_tensor.dims()[0] != 0 && unpacked_tensor.dims()[0] != 1) { @@ -482,7 +492,8 @@ static bool IsInternalQuantizedNodeUnit(const NodeUnit& node_unit) { int32_t input_type; ORT_ENFORCE(GetType(*node.InputDefs()[0], input_type)); - return input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; + return input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; } bool GetType(const NodeArg& node_arg, int32_t& type) { diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.h b/onnxruntime/core/providers/vsinpu/vsinpu_util.h index 9ec580bf02e7..09ed8d07dcd8 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_util.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.h @@ -118,7 +118,7 @@ bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type); bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit); void GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::filesystem::path& model_path, float& scale, int32_t& zero_point, std::optional>& pcq_scales, std::optional>& pcq_zps); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 4eaa4855ff7b..847db6a9975c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -37,7 +37,6 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod // skip the weight for conv as we need to transpose for preferred layout NHWC. if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W - model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); } } @@ -168,7 +167,7 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, if (is_conv == 1) dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 else - dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv weight + dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight SafeInt num_elements = SafeInt(Product(dest_shape)); @@ -186,7 +185,6 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: element_size = sizeof(float); break; - break; default: break; } @@ -232,9 +230,11 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = emscripten::val::object(); + const auto& initializers(model_builder.GetInitializerTensors()); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); @@ -249,6 +249,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; + const bool is_constant_weight = Contains(initializers, weight_name); // Support conv1d by prepending a 1 or 2 size dimensions. if (is_conv1d) { // Reshape input. @@ -274,12 +275,15 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val options = emscripten::val::object(); ORT_RETURN_IF_ERROR(SetConvBaseOptions( model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); + bool depthwise = false; if (op_type == "Conv" || op_type == "ConvInteger") { int groups = options["groups"].as(); if (is_nhwc) { - bool depthwise = (groups == input_shape[3] && groups != 1); + depthwise = (groups == input_shape[3] && groups != 1); options.set("inputLayout", emscripten::val("nhwc")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); + if (is_constant_weight) { + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); + } if (!depthwise) { options.set("filterLayout", emscripten::val("ohwi")); } else { @@ -290,15 +294,37 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (is_nhwc) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); + if (is_constant_weight) { + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); + } } } emscripten::val filter = model_builder.GetOperand(weight_name); - if (!is_nhwc && is_conv1d) { - // Reshape weight to 4D for conv1d with NCHW preferred layout. - std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + + if (is_conv1d) { + // Reshape weight to 4D for conv1d. + if (!is_nhwc || !is_constant_weight) { + // The weight_shape has been appended 1's, reshape weight operand. + std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); + filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + } + } + + emscripten::val transpose_options = emscripten::val::object(); + if (is_nhwc && !is_constant_weight) { + // For NHWC preferred layout, if the weight is input: + // - Transpose it from iohw -> ohwi for convTranspose. + // - Transpose it from oihw -> ihwo for depthwise conv. + // - Transpose it from oihw -> ohwi for conv. + std::vector perm(4); + if (op_type == "ConvTranspose" || depthwise) { + perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight + } else { + perm = {0, 2, 3, 1}; // L_0231 + } + transpose_options.set("permutation", emscripten::val::array(perm)); + filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); } if (op_type == "Conv") { @@ -371,13 +397,6 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } - // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. - // https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 - if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; - } - return true; } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 13ed29667deb..e839d6d17b7d 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -17,24 +17,12 @@ namespace onnxruntime { -WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags, - const std::string& webnn_threads_number, const std::string& webnn_power_flags) +WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { - // Create WebNN context and graph builder. - const emscripten::val ml = emscripten::val::global("navigator")["ml"]; - if (!ml.as()) { - ORT_THROW("Failed to get ml from navigator."); - } - emscripten::val context_options = emscripten::val::object(); - context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; - // Set "numThreads" if it's not default 0. - if (webnn_threads_number.compare("0") != 0) { - context_options.set("numThreads", stoi(webnn_threads_number)); - } } else { preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { @@ -45,11 +33,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f ORT_THROW("Unknown WebNN deviceType."); } } - if (webnn_power_flags.compare("default") != 0) { - context_options.set("powerPreference", emscripten::val(webnn_power_flags)); - } - wnn_context_ = ml.call("createContext", context_options).await(); + wnn_context_ = emscripten::val::module_property("currentContext"); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } @@ -96,6 +81,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto& logger = *GetLogger(); + if (!wnn_builder_.as()) { + // The GetCapability function may be called again after Compile due to the logic in the + // PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc). + // We need to re-create the wnn_builder_ here to avoid it's been released in last Compile. + wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_); + } + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger); if (node_groups.empty()) { @@ -337,6 +329,9 @@ common::Status WebNNExecutionProvider::Compile(const std::vector> @@ -43,8 +42,8 @@ class WebNNExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; private: - emscripten::val wnn_context_ = emscripten::val::object(); - emscripten::val wnn_builder_ = emscripten::val::object(); + emscripten::val wnn_context_ = emscripten::val::undefined(); + mutable emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; webnn::WebnnDeviceType wnn_device_type_; diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc index 11acec8b1f35..7792aeabaabf 100644 --- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -10,27 +10,22 @@ using namespace onnxruntime; namespace onnxruntime { struct WebNNProviderFactory : IExecutionProviderFactory { - WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number, - const std::string& webnn_power_flags) - : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {} + explicit WebNNProviderFactory(const std::string& webnn_device_flags) + : webnn_device_flags_(webnn_device_flags) {} ~WebNNProviderFactory() override {} std::unique_ptr CreateProvider() override; std::string webnn_device_flags_; - std::string webnn_threads_number_; - std::string webnn_power_flags_; }; std::unique_ptr WebNNProviderFactory::CreateProvider() { - return std::make_unique(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_); + return std::make_unique(webnn_device_flags_); } std::shared_ptr WebNNProviderFactoryCreator::Create( const ProviderOptions& provider_options) { - return std::make_shared(provider_options.at("deviceType"), - provider_options.at("numThreads"), - provider_options.at("powerPreference")); + return std::make_shared(provider_options.at("deviceType")); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 0366d9f893f7..b815cc1570c9 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -5,7 +5,7 @@ #include -#include "core/common/gsl.h" +#include #include "core/common/inlined_containers_fwd.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/transpose_helper.h" diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 7102dbfc750e..4c782f647371 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -10,7 +10,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/data_types.h" #include "core/framework/error_code_helper.h" #include "core/framework/onnxruntime_typeinfo.h" diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 688ee76c591f..db8b97f6d2c1 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -127,11 +127,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, } else if (strcmp(provider_name, "WEBNN") == 0) { #if defined(USE_WEBNN) std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu"); - std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0"); - std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default"); provider_options["deviceType"] = deviceType; - provider_options["numThreads"] = numThreads; - provider_options["powerPreference"] = powerPreference; options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index 43843da3fb96..dbf961ab5b03 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -15,7 +15,7 @@ #pragma once #include -#include "core/common/gsl.h" +#include #if defined(_MSC_VER) #define ORT_FORCEINLINE __forceinline diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9bc2328cc71b..ac959d5c061f 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -206,10 +206,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, "GroupQueryAttention": self._infer_GroupQueryAttention, - "SparseAttention": self._infer_SparseAttention, - "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, + "MatMulNBits": self._infer_MatMulNBits, "MultiHeadAttention": self._infer_MultiHeadAttention, "NhwcConv": self._infer_NhwcConv, "PackedAttention": self._infer_PackedAttention, @@ -223,8 +222,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "RestorePadding": self._infer_RestorePadding, "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipGroupNorm": self._infer_SkipGroupNorm, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + "SparseAttention": self._infer_SparseAttention, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -1256,6 +1257,25 @@ def _infer_MatMul(self, node): # noqa: N802 def _infer_MatMulInteger(self, node): # noqa: N802 self._compute_matmul_shape(node, onnx.TensorProto.INT32) + def _infer_MatMulNBits(self, node): # noqa: N802 + lhs_shape = self._get_shape(node, 0) + rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")] + lhs_rank = len(lhs_shape) + assert lhs_rank > 0 + if lhs_rank == 1: + new_shape = rhs_shape[1:] + else: + new_shape = lhs_shape[:-1] + rhs_shape[1:] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[-1], rhs_shape[0]], + allow_broadcast=False, + ) + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + def _infer_NonMaxSuppression(self, node): # noqa: N802 selected = str(self._new_symbolic_dim_from_output(node)) vi = self.known_vi_[node.output[0]] diff --git a/onnxruntime/python/tools/transformers/dev_benchmark.cmd b/onnxruntime/python/tools/transformers/dev_benchmark.cmd index 7a9b3254a170..82137de3c0f3 100644 --- a/onnxruntime/python/tools/transformers/dev_benchmark.cmd +++ b/onnxruntime/python/tools/transformers/dev_benchmark.cmd @@ -41,7 +41,7 @@ set input_counts=1 REM Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased set models_to_test=bert-base-cased -REM If you have mutliple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: +REM If you have multiple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: REM set CUDA_VISIBLE_DEVICES=1 REM This script will generate a logs file with a list of commands used in tests. @@ -163,4 +163,4 @@ IF %FileSize% LSS 10 goto :EOF python -c "import sys; lines=sys.stdin.readlines(); h=lines[0]; print(''.join([h]+list(sorted(set(lines)-set([h])))))" < %1 > sort_%1 FindStr "[^,]" sort_%1 > summary_%1 DEL sort_%1 -goto :EOF \ No newline at end of file +goto :EOF diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 95f40af3fd74..9cc4878e8022 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -694,7 +694,7 @@ def __init__(self, model, num_heads, hidden_size): self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self) self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) - # TODO: consider retrive max_distance from model. + # TODO: consider retrieve max_distance from model. # math.log(max_distance / (num_buckets // 2)) self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128) diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index 64d6ecde618f..77d0c3a76624 100755 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -62,7 +62,7 @@ input_counts=1 # Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased models_to_test="bert-base-cased roberta-base distilbert-base-uncased" -# If you have mutliple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: +# If you have multiple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: # export CUDA_VISIBLE_DEVICES=1 # This script will generate a logs file with a list of commands used in tests. diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 336e0f197fcc..9ab4a82463d5 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -8,7 +8,7 @@ #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index b831507988da..e0891c7ced63 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -6,7 +6,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 6ce9f5de68f1..5f94d30112f0 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -4,7 +4,7 @@ #include #include #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index 73da82d4bb03..79070f0788f2 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -4,7 +4,7 @@ #include #include #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc index d987a1cae427..69789b84832e 100644 --- a/onnxruntime/test/contrib_ops/sampling_test.cc +++ b/onnxruntime/test/contrib_ops/sampling_test.cc @@ -4,7 +4,7 @@ #include #include #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 1804c09043c7..9278541b0751 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -139,7 +139,7 @@ TEST(TransformerTest, CastRemovalDoesNotLowerPrecisionTest) { status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - // When casting f64 -> f32 -> f64 we should not be optimising away the cast since there is a loss of precision. + // When casting f64 -> f32 -> f64 we should not be optimizing away the cast since there is a loss of precision. EXPECT_EQ(graph.NumberOfNodes(), 2); } @@ -171,7 +171,7 @@ TEST(TransformerTest, CastRemovalDoesNotRemoveSignednessTest) { status = graph.Resolve(); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - // When casting i32 -> ui32 -> i32 we should not be optimising away the cast since applying the casts produces a very different result. + // When casting i32 -> ui32 -> i32 we should not be optimizing away the cast since applying the casts produces a very different result. EXPECT_EQ(graph.NumberOfNodes(), 2); } diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h index 0a99b4bc8021..51b02ee3e7f8 100644 --- a/onnxruntime/test/framework/test_utils.h +++ b/onnxruntime/test/framework/test_utils.h @@ -9,7 +9,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/ort_value.h" -#include "core/common/gsl.h" +#include #ifdef USE_CUDA #include "core/providers/providers.h" diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 71a6123b868b..f391027de4d5 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -419,9 +419,10 @@ static size_t SQNBitGemmRegisterAllShortExecuteTests() { return count; } -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (is_short_execute) { - return SQNBitGemmRegisterAllShortExecuteTests() > 0; - } - return false; -}); +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return SQNBitGemmRegisterAllShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 1e2d34e5aefc..0282d09f340b 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -339,8 +339,10 @@ class ModelTestBuilder { bool use_ms_domain = false) { std::vector input_args; input_args.push_back(input_arg); - input_args.push_back(Make1DInitializer(input_scales)); - input_args.push_back(Make1DInitializer(input_zero_points)); + + std::vector qparams_shape = {static_cast(input_scales.size())}; + input_args.push_back(MakeInitializer(qparams_shape, input_scales)); + input_args.push_back(MakeInitializer(qparams_shape, input_zero_points)); std::string domain = use_ms_domain ? kMSDomain : ""; return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes); @@ -415,8 +417,10 @@ class ModelTestBuilder { bool use_ms_domain = false) { std::vector input_args; input_args.push_back(input_arg); - input_args.push_back(Make1DInitializer(input_scales)); - input_args.push_back(Make1DInitializer(input_zero_points)); + + std::vector qparams_shape = {static_cast(input_scales.size())}; + input_args.push_back(MakeInitializer(qparams_shape, input_scales)); + input_args.push_back(MakeInitializer(qparams_shape, input_zero_points)); std::string domain = use_ms_domain ? kMSDomain : ""; return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes); diff --git a/onnxruntime/test/optimizer/initializer_test.cc b/onnxruntime/test/optimizer/initializer_test.cc index 9e55d9b2ef92..522e96e762d5 100644 --- a/onnxruntime/test/optimizer/initializer_test.cc +++ b/onnxruntime/test/optimizer/initializer_test.cc @@ -8,7 +8,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/platform/file_io_test.cc b/onnxruntime/test/platform/file_io_test.cc index 3611e3b33446..ccc703716844 100644 --- a/onnxruntime/test/platform/file_io_test.cc +++ b/onnxruntime/test/platform/file_io_test.cc @@ -14,7 +14,7 @@ #include #endif -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc b/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc index f7b8a61f48da..c7fc73456dcb 100644 --- a/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#include "core/common/gsl.h" +#include using namespace std; namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 4c77908322df..421561a5a859 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -5,7 +5,7 @@ #include "boost/mp11.hpp" -#include "core/common/gsl.h" +#include #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 1da9c9df299c..af54ae96ef86 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -35,8 +35,10 @@ void RunSliceTest(const std::vector& input_dims, excluded_providers.insert(excluded_providers_input.cbegin(), excluded_providers_input.cend()); // NNAPI EP does not support empty output + // VSINPU EP does not support empty output if (std::any_of(output_dims.cbegin(), output_dims.cend(), [](int64_t i) { return i == 0; })) { excluded_providers.insert(kNnapiExecutionProvider); + excluded_providers.insert(kVSINPUExecutionProvider); } // TODO: ORT behavior when step < 0 and end = INT_MAX is wrong. Fix it and @@ -515,6 +517,9 @@ TEST(SliceTest, Slice1D_ReverseAllAxes_1) { if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{2,2}] did not match run output shape [{0,0}] for output"; } + if (DefaultVSINPUExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{4}] did not match run output shape [{0}] for output"; + } RunSliceTest({4}, {1.0f, 2.0f, 3.0f, 4.0f}, diff --git a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc index 6237521b34df..03db49a313af 100644 --- a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc @@ -3,7 +3,7 @@ #include "gtest/gtest.h" -#include "core/common/gsl.h" +#include #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc index d8384b432786..022c6250138d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc @@ -10,7 +10,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #include "core/providers/cuda/cuda_allocator.h" diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index b3e1025e7367..1af7bdea68b6 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -14,6 +14,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/util/include/asserts.h" +#include "test/util/include/current_test_name.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test/test_environment.h" @@ -36,10 +37,6 @@ using namespace ::onnxruntime::logging; namespace onnxruntime { namespace test { -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - #if !defined(ORT_MINIMAL_BUILD) // Since NNAPI EP handles Reshape and Flatten differently, @@ -65,7 +62,8 @@ TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) { feeds.insert(std::make_pair("X", ml_value_x)); feeds.insert(std::make_pair("Y", ml_value_y)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.ReshapeFlattenTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -88,7 +86,8 @@ TEST(NnapiExecutionProviderTest, SigmoidSupportedInputRankTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.SigmoidSupportedInputRankTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds, {ExpectedEPNodeAssignment::None} /* params */); #else @@ -115,7 +114,8 @@ TEST(NnapiExecutionProviderTest, DynamicGraphInputTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.DynamicGraphInputTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -144,7 +144,8 @@ TEST(NnapiExecutionProviderTest, InternalUint8SupportTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.InternalUint8SupportTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -208,7 +209,8 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { feeds.insert(std::make_pair("Y", ml_value_y)); feeds.insert(std::make_pair("Z", ml_value_z)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.FunctionTest", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else @@ -273,7 +275,8 @@ static void RunQDQModelTest( const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); #if defined(__ANDROID__) - RunAndVerifyOutputsWithEP(model_data_span, "NnapiExecutionProviderTest.TestQDQModel", + RunAndVerifyOutputsWithEP(model_data_span, + CurrentTestName(), std::make_unique(0), helper.feeds_, params); #else @@ -513,6 +516,31 @@ TEST(NnapiExecutionProviderTest, TestGather) { {ExpectedEPNodeAssignment::All}); } +TEST(NnapiExecutionProviderTest, SharedInitializersDoNotGetSkipped) { + // NNAPI EP's Clip op builder will mark the max initializer as skipped but it is also used by the Div op. + // Test that the shared initializer is still present in the NNAPI model for the Div op. + constexpr auto* model_file_name = ORT_TSTR("testdata/clip_div_shared_initializer.onnx"); + +#if defined(__ANDROID__) + AllocatorPtr cpu_allocator = std::make_shared(); + + std::vector x_dims{3, 2}; + std::vector x_values(3.0f, 3 * 2); + OrtValue ml_value_x; + CreateMLValue(cpu_allocator, x_dims, x_values, &ml_value_x); + + NameMLValMap feeds{{"input_0", ml_value_x}}; + + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), + std::make_unique(0), + feeds, + {ExpectedEPNodeAssignment::All}); +#else + TestModelLoad(model_file_name, std::make_unique(0), ExpectedEPNodeAssignment::All); +#endif +} + #endif // !(ORT_MINIMAL_BUILD) TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) { @@ -541,7 +569,8 @@ TEST(NnapiExecutionProviderTest, TestOrtFormatModel) { NameMLValMap feeds; feeds.insert(std::make_pair("Input3", ml_value)); - RunAndVerifyOutputsWithEP(model_file_name, "NnapiExecutionProviderTest.TestOrtFormatModel", + RunAndVerifyOutputsWithEP(model_file_name, + CurrentTestName(), std::make_unique(0), feeds); #else diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 5177a629ce29..b07951d2a2e6 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -182,7 +182,12 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s static_cast(weight_quant_axis), true); TensorShape weights_shape = weights_def.GetTensorShape(); - std::vector quantized_weights(weights_shape.Size()); + std::vector quantized_weights; + size_t num_weight_storage_elems = weights_shape.Size(); + if constexpr (std::is_same_v || std::is_same_v) { + num_weight_storage_elems = Int4x2::CalcNumInt4Pairs(weights_shape.Size()); + } + quantized_weights.resize(num_weight_storage_elems); QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, weight_scales, weight_zero_points, weight_quant_axis); @@ -727,6 +732,80 @@ TEST_F(QnnHTPBackendTests, ConvU8S8S32_PerChannel) { 13); // opset } +// Test per-channel QDQ Conv with INT4 weights. in0: u16, in1 (weight): s4, in2 (bias): s32, out: u8 +TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} + +// Test per-channel QDQ Conv with INT4 weights. in0: u16, in1 (weight): s4, in2 (bias): s32, out: u8 +// TODO(adrianlizarraga): Investigate inaccuracy for QNN EP. +// +// Output values for all EPs: +// CPU EP (f32 model): 25.143 21.554 17.964 10.785 7.195 3.605 -3.574 -7.164 -10.753 +// CPU EP (qdq model): 24.670 21.103 17.536 10.254 6.689 2.972 -4.161 -7.728 -10.700 +// QNN EP (qdq model): 27.186 27.186 27.186 21.541 6.685 -8.022 -10.548 -10.548 -10.548 +TEST_F(QnnHTPBackendTests, DISABLED_ConvU16S4S32_PerChannel_AccuracyIssue) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + // Wrote out input data explicitly for easier reproduction. + // std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size()); + std::vector input_data = {-10.000f, -9.355f, -8.710f, -8.065f, -7.419f, -6.774f, -6.129f, -5.484f, -4.839f, + -4.194f, -3.548f, -2.903f, -2.258f, -1.613f, -0.968f, -0.323f, 0.323f, 0.968f, + 1.613f, 2.258f, 2.903f, 3.548f, 4.194f, 4.839f, 5.484f, 6.129f, 6.774f, + 7.419f, 8.065f, 8.710f, 9.355f, 10.000f}; + + // std::vector weight_data = GetFloatDataInRange(-1.0f, 1.0f, TensorShape(weight_shape).Size()); + std::vector weight_data = {-1.000f, -0.913f, -0.826f, -0.739f, -0.652f, -0.565f, -0.478f, -0.391f, -0.304f, + -0.217f, -0.130f, -0.043f, 0.043f, 0.130f, 0.217f, 0.304f, 0.391f, 0.478f, + 0.565f, 0.652f, 0.739f, 0.826f, 0.913f, 1.000f}; + + // std::vector bias_data = GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size()); + std::vector bias_data = {-1.000f, 0.000f, 1.000f}; + + TestInputDef input_def(input_shape, false, input_data); + TestInputDef weight_def(weight_shape, true, weight_data); + TestInputDef bias_def(bias_shape, true, bias_data); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} + // Test per-channel QDQ Conv is rejected with weight axis != 0 TEST_F(QnnHTPBackendTests, Conv_PerChannel_UnsupportedAxis) { std::vector input_shape = {1, 2, 4, 4}; diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 6173f46839a8..9489d354755e 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -236,6 +236,51 @@ TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { Ort::Session session(*ort_env, ort_model_path, so); } +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +TEST_F(QnnHTPBackendTests, RunConvInt4Model) { + Ort::SessionOptions so; + + so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Disable fallback to the CPU EP. + so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + so.AppendExecutionProvider("QNN", options); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv.int4_weights.qdq.onnx"; + Ort::Session session(*ort_env, ort_model_path, so); + + TensorShape input_shape = {1, 3, 8, 8}; + std::vector input0_data(input_shape.Size(), 0.2f); + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add input0 + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), &input_shape[0], input_shape.NumDimensions())); + ort_input_names.push_back("input_0"); + + // Run session and get outputs + std::array output_names{"output_0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output shape. + Ort::Value& ort_output = ort_outputs[0]; + auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); + std::vector output_shape = typeshape.GetShape(); + + EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 5, 6, 6)); +} +#endif // #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + // Helper function that runs an ONNX model with a NHWC Resize operator to test that // type/shape inference succeeds during layout transformation. // Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index dd47e1df8000..ad54e644af3f 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -34,6 +34,15 @@ struct QuantParams { QType zero_point; static QuantParams Compute(float rmin, float rmax, bool symmetric = false) { + return Compute( + rmin, + rmax, + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()), + symmetric); + } + + static QuantParams Compute(float rmin, float rmax, QType qmin, QType qmax, bool symmetric = false) { // Ensure a minimum range of 0.0001 (required by QNN) rmax = std::max(rmax, rmin + 0.0001f); @@ -41,27 +50,27 @@ struct QuantParams { rmin = std::min(rmin, 0.0f); rmax = std::max(rmax, 0.0f); - constexpr float qmin = static_cast(std::numeric_limits::min()); - constexpr float qmax = static_cast(std::numeric_limits::max()); - if (symmetric) { const float abs_max = std::max(std::abs(rmin), std::abs(rmax)); rmax = abs_max; rmin = -abs_max; } - const float scale = (rmax - rmin) / (qmax - qmin); + float qmin_flt = static_cast(qmin); + float qmax_flt = static_cast(qmax); + const float scale = (rmax - rmin) / (qmax_flt - qmin_flt); float initial_zero_point = 0.0f; if (symmetric) { // Symmetric uses same formula for zero-point as asymmetric, but we can cancel out terms for // increased numerical accuracy. - initial_zero_point = (qmin + qmax) / 2.0f; + initial_zero_point = (qmin_flt + qmax_flt) / 2.0f; } else { - initial_zero_point = qmin - (rmin / scale); + initial_zero_point = qmin_flt - (rmin / scale); } - const QType zero_point = static_cast(RoundHalfToEven(std::max(qmin, std::min(qmax, initial_zero_point)))); + const QType zero_point = static_cast(RoundHalfToEven(std::max(qmin_flt, + std::min(qmax_flt, initial_zero_point)))); return QuantParams{scale, zero_point}; } @@ -238,7 +247,7 @@ struct TestInputDef { assert(which_type == 0); const std::vector& raw_data = std::get(data_info_).data; - std::pair init_range(std::numeric_limits::max(), std::numeric_limits::min()); + std::pair init_range(std::numeric_limits::max(), std::numeric_limits::lowest()); std::vector> per_axis_ranges(num_ranges, init_range); TensorShape shape(shape_); size_t num_blocks = shape.SizeToDimension(axis); @@ -292,6 +301,37 @@ static void GetTestInputQuantParamsPerChannel(const TestInputDef& input_d } } +// Define functions to get the quantization parameters (i.e., scale/zp) for input data that will be quantized +// as int4 per-channel. +#define DEF_GET_INPUT_QPARAMS_PER_CHAN_INT4_FUNC(INT4x2_TYPE) \ + template <> \ + inline void GetTestInputQuantParamsPerChannel(const TestInputDef& input_def, \ + std::vector& scales, \ + std::vector& zero_points, \ + size_t axis, bool symmetric) { \ + using UnpackedType = typename INT4x2_TYPE::UnpackedType; \ + const auto f32_ranges = input_def.GetRangePerChannel(axis); \ + const size_t num_ranges = f32_ranges.size(); \ + \ + scales.resize(num_ranges); \ + zero_points.resize(INT4x2_TYPE::CalcNumInt4Pairs(num_ranges)); \ + \ + for (size_t i = 0; i < num_ranges; i++) { \ + const auto& range = f32_ranges[i]; \ + QuantParams params = QuantParams::Compute(range.first, range.second, \ + INT4x2_TYPE::min_val, \ + INT4x2_TYPE::max_val, symmetric); \ + scales[i] = params.scale; \ + \ + size_t r = i >> 1; \ + size_t c = i & 0x1; \ + zero_points[r].SetElem(c, params.zero_point); \ + } \ + } + +DEF_GET_INPUT_QPARAMS_PER_CHAN_INT4_FUNC(Int4x2) +DEF_GET_INPUT_QPARAMS_PER_CHAN_INT4_FUNC(UInt4x2) + template static void QuantizeValues(gsl::span input, gsl::span output, const TensorShape& shape, gsl::span scales, gsl::span zero_points, @@ -332,6 +372,52 @@ static void QuantizeValues(gsl::span input, gsl::span \ + inline void QuantizeValues(gsl::span input, \ + gsl::span output, \ + const TensorShape& shape, \ + gsl::span scales, \ + gsl::span zero_points, \ + std::optional axis) { \ + using UnpackedType = typename INT4x2_TYPE::UnpackedType; \ + const size_t input_rank = shape.NumDimensions(); \ + const size_t num_int4_elems = static_cast(shape.Size()); \ + ORT_ENFORCE(input.size() == num_int4_elems); \ + ORT_ENFORCE(output.size() == INT4x2_TYPE::CalcNumInt4Pairs(num_int4_elems)); \ + \ + size_t block_count = 1; \ + size_t broadcast_dim = 1; \ + size_t block_size = num_int4_elems; \ + \ + if (axis.has_value()) { \ + size_t axis_no_neg = *axis < 0 ? static_cast(*axis) + input_rank : static_cast(*axis); \ + block_count = shape.SizeToDimension(axis_no_neg); \ + broadcast_dim = shape[axis_no_neg]; \ + block_size = shape.SizeFromDimension(axis_no_neg + 1); \ + } \ + \ + ORT_ENFORCE(scales.size() == broadcast_dim); \ + ORT_ENFORCE(zero_points.empty() || zero_points.size() == INT4x2_TYPE::CalcNumInt4Pairs(broadcast_dim)); \ + \ + size_t i = 0; \ + \ + for (size_t n = 0; n < block_count; n++) { \ + for (size_t bd = 0; bd < broadcast_dim; bd++) { \ + size_t bd_i = bd >> 1; /* bd / 2 */ \ + size_t bd_j = bd & 0x1; /* bd % 2 */ \ + UnpackedType zp = !zero_points.empty() ? zero_points[bd_i].GetElem(bd_j) : 0; \ + QUANT_FUNC(&input[i], output.data(), i, i + block_size, scales[bd], INT4x2_TYPE(zp, 0), nullptr); \ + i += block_size; \ + } \ + } \ + assert(i == (block_count * broadcast_dim * block_size)); \ + } + +DEF_QUANTIZE_VALUES_INT4_FUNC(Int4x2, ParQuantizeLinearStdS4) +DEF_QUANTIZE_VALUES_INT4_FUNC(UInt4x2, ParQuantizeLinearStdU4) + /** * Inferences a given serialized model. Returns output values via an out-param. * @@ -414,6 +500,10 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); + + // Uncomment to dump LOGGER() output to stdout. + // logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(log_severity); // Create float model and serialize it to a string. diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index eca1430448e8..29680c98fb4d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -594,6 +594,55 @@ def test_dequantize_linear_ms_domain(self): ] self._check_shapes(graph, inferred.graph, expected_shapes) + def test_matmulnbits(self): + """ + Test ORT MatMulNBits op. + Check that the output shape is propagated from the inputs and that the output data + type comes from the first input. + """ + b_np = numpy.random.randint(0, 255, (4, 1, 8), numpy.uint8) + b = numpy_helper.from_array(b_np, name="b") + scale_np = numpy.random.rand(4).astype(numpy.float32) + scale = numpy_helper.from_array(scale_np, name="scale") + zero_point_np = numpy.random.randint(0, 255, (4), numpy.uint8) + zero_point = numpy_helper.from_array(zero_point_np, name="zero_point") + + initializers = [b, scale, zero_point] + + kwargs = {"K": 10, "N": 4, "block_size": 16} + + nodes = [ + helper.make_node( + "MatMulNBits", + inputs=[ + "input_f32", + "b", + "scale", + "zero_point", + ], + outputs=["output_f32"], + **kwargs, + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["x", 2, 3, 10]), + ] + + outputs = [ + helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "MatMulNBits_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["x", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 58c185f818df..eacd41e6b9c6 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -31,7 +31,7 @@ #include "test_fixture.h" #include "utils.h" #include "custom_op_utils.h" -#include "core/common/gsl.h" +#include #ifdef _WIN32 #include diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc index e8160d1619cb..5fa4fb31e1c9 100644 --- a/onnxruntime/test/shared_lib/test_nontensor_types.cc +++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc @@ -9,7 +9,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "test_allocator.h" -#include "core/common/gsl.h" +#include #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/testdata/clip_div_shared_initializer.onnx b/onnxruntime/test/testdata/clip_div_shared_initializer.onnx new file mode 100644 index 000000000000..223d2d2febbb Binary files /dev/null and b/onnxruntime/test/testdata/clip_div_shared_initializer.onnx differ diff --git a/onnxruntime/test/testdata/clip_div_shared_initializer.py b/onnxruntime/test/testdata/clip_div_shared_initializer.py new file mode 100644 index 000000000000..e3c4ab438b0f --- /dev/null +++ b/onnxruntime/test/testdata/clip_div_shared_initializer.py @@ -0,0 +1,33 @@ +from onnx import TensorProto, checker, helper, save + +graph_proto = helper.make_graph( + [ + helper.make_node( + "Clip", + inputs=["input_0", "initializer_0", "initializer_1"], + outputs=["clip_output"], + name="clip", + ), + helper.make_node( + "Div", + inputs=["clip_output", "initializer_1"], + outputs=["output_0"], + name="div", + ), + ], + "Main_graph", + [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, [3, 2]), + ], + [ + helper.make_tensor("initializer_0", TensorProto.FLOAT, [], [0.0]), + helper.make_tensor("initializer_1", TensorProto.FLOAT, [], [6.0]), + ], +) + +model = helper.make_model(graph_proto) +checker.check_model(model, True) +save(model, "clip_div_shared_initializer.onnx") diff --git a/onnxruntime/test/testdata/conv.int4_weights.qdq.onnx b/onnxruntime/test/testdata/conv.int4_weights.qdq.onnx new file mode 100644 index 000000000000..56f965d0b410 Binary files /dev/null and b/onnxruntime/test/testdata/conv.int4_weights.qdq.onnx differ diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc index 6a2d27ee9517..2d7d8e7cc735 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "my_ep_factory.h" #include "my_execution_provider.h" -#include "core/common/gsl.h" +#include #include "core/providers/shared/common.h" #include #include "core/framework/provider_options_utils.h" diff --git a/onnxruntime/test/testdata/make_conv_int4_weights_model.py b/onnxruntime/test/testdata/make_conv_int4_weights_model.py new file mode 100644 index 000000000000..004342b53102 --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_int4_weights_model.py @@ -0,0 +1,98 @@ +import numpy as np +import onnx + +from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize +from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config + +INPUT0_SHAPE = (1, 3, 8, 8) +INPUT0_NAME = "input_0" + + +def create_f32_model(): + input_0 = onnx.helper.make_tensor_value_info(INPUT0_NAME, onnx.TensorProto.FLOAT, INPUT0_SHAPE) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + weight_data = [ + [ + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + ], + [ + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + ], + [ + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + ], + [ + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + [[-1.0, -0.8, -0.6], [-0.1333, 0.0, -0.1333], [0.6, 0.8, 1.0]], # range = 2.0, scale = 2.0/15, zp = -3 + ], + [ + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + [[-1.5, -1.0, -0.5], [-0.2, 0.0, 0.2], [0.5, 1.0, 1.5]], # range = 3.0, scale = 3.0/15, zp = 0 + ], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias_data = [-10.0, -8.0, 0.0, 8.0, 10.0] + bias = onnx.numpy_helper.from_array(np.array(bias_data, dtype=np.float32), "bias") + + conv_node = onnx.helper.make_node("Conv", [INPUT0_NAME, "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + + return model + + +class DataReader(CalibrationDataReader): + def __init__(self): + self.enum_data = None + self.data_list = [] + + # Generate 10 random input values for calibration + for _ in range(10): + input_data = {INPUT0_NAME: np.random.random(INPUT0_SHAPE).astype(np.float32)} + self.data_list.append(input_data) + + self.datasize = len(self.data_list) + + def get_next(self): + if self.enum_data is None: + self.enum_data = iter(self.data_list) + return next(self.enum_data, None) + + def rewind(self): + self.enum_data = None + + +def create_qdq_model(model_f32): + # Use tensor quantization overrides to quantize Conv's weight input to 4 bits on axis 0. + init_overrides = {"weight": [{"quant_type": QuantType.QInt4, "axis": 0, "symmetric": True}]} + qnn_config = get_qnn_qdq_config( + model_f32, + DataReader(), + init_overrides=init_overrides, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QUInt8, + ) + + quantize(model_f32, "conv.int4_weights.qdq.onnx", qnn_config) + + +if __name__ == "__main__": + model_f32 = create_f32_model() + create_qdq_model(model_f32) diff --git a/onnxruntime/test/util/include/test_utils.h b/onnxruntime/test/util/include/test_utils.h index 48f0d7c2ab1f..f55295ac8aec 100644 --- a/onnxruntime/test/util/include/test_utils.h +++ b/onnxruntime/test/util/include/test_utils.h @@ -11,7 +11,7 @@ #include #include -#include "core/common/gsl.h" +#include #include "core/framework/execution_provider.h" #include "core/framework/framework_common.h" #include "core/framework/ort_value.h" diff --git a/onnxruntime/tool/etw/eparser.cc b/onnxruntime/tool/etw/eparser.cc index 526ba6de8196..ff0348d721eb 100644 --- a/onnxruntime/tool/etw/eparser.cc +++ b/onnxruntime/tool/etw/eparser.cc @@ -7,7 +7,7 @@ // Get the length of the property data. For MOF-based events, the size is inferred from the data type // of the property. For manifest-based events, the property can specify the size of the property value -// using the length attribute. The length attribue can specify the size directly or specify the name +// using the length attribute. The length attribute can specify the size directly or specify the name // of another property in the event data that contains the size. If the property does not include the // length attribute, the size is inferred from the data type. The length will be zero for variable // length, null-terminated strings and structures. @@ -16,7 +16,7 @@ DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, // Get the size of the array. For MOF-based events, the size is specified in the declaration or using // the MAX qualifier. For manifest-based events, the property can specify the size of the array -// using the count attribute. The count attribue can specify the size directly or specify the name +// using the count attribute. The count attribute can specify the size directly or specify the name // of another property in the event data that contains the size. DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize); diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h index 67a991ea2cce..762258a45221 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.h +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h @@ -9,9 +9,7 @@ #include #include -namespace onnxruntime { -namespace language_interop_ops { -namespace torch { +namespace onnxruntime::language_interop_ops::torch { typedef std::vector (*CustomFunctionRunnerType)(const char* func_name_char, void* callback, @@ -124,6 +122,4 @@ class OrtTorchFunctionPool final { std::mutex mutex_; }; -} // namespace torch -} // namespace language_interop_ops -} // namespace onnxruntime +} // namespace onnxruntime::language_interop_ops::torch diff --git a/orttraining/orttraining/core/framework/torch/dlpack_python.cc b/orttraining/orttraining/core/framework/torch/dlpack_python.cc index d512dc72a438..f9b237f05125 100644 --- a/orttraining/orttraining/core/framework/torch/dlpack_python.cc +++ b/orttraining/orttraining/core/framework/torch/dlpack_python.cc @@ -3,10 +3,7 @@ #include "orttraining/core/framework/torch/dlpack_python.h" -namespace onnxruntime { -namespace training { -namespace framework { -namespace torch { +namespace onnxruntime::training::framework::torch { static void DlpackCapsuleDestructor(PyObject* data) { DLManagedTensor* dlmanged_tensor = reinterpret_cast( @@ -35,7 +32,4 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) { return ort_value; } -} // namespace torch -} // namespace framework -} // namespace training -} // namespace onnxruntime +} // namespace onnxruntime::training::framework::torch diff --git a/orttraining/orttraining/core/framework/torch/dlpack_python.h b/orttraining/orttraining/core/framework/torch/dlpack_python.h index 37bae2ab3702..9b641971dcea 100644 --- a/orttraining/orttraining/core/framework/torch/dlpack_python.h +++ b/orttraining/orttraining/core/framework/torch/dlpack_python.h @@ -8,10 +8,7 @@ #include "core/dlpack/dlpack_converter.h" #include "orttraining/core/framework/torch/python_common.h" -namespace onnxruntime { -namespace training { -namespace framework { -namespace torch { +namespace onnxruntime::training::framework::torch { // Allocate a new Capsule object, which takes the ownership of OrtValue. // Caller is responsible for releasing. @@ -22,7 +19,4 @@ PyObject* ToDlpack(OrtValue ort_value); // create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion. OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor); -} // namespace torch -} // namespace framework -} // namespace training -} // namespace onnxruntime +} // namespace onnxruntime::training::framework::torch diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h index 450a5048aea4..b80acd6c4791 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.h +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h @@ -22,7 +22,7 @@ namespace torch { // For handling temporary PyObject pointer newly created with Py_XXX APIs, here is our practice: // Convention: // Wrap those PyObject* in format of "PythonObjectPtr(Py_XXX(), PythonObjectDeleter)". -// Explaination: +// Explanation: // That means, for the PyObject* created by Py_XXX(), its refcnt will be decreased by one // in the PythonObjectDeleter which is triggered once lifetime of PythonObjectPtr instance // ends. diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.cc b/orttraining/orttraining/core/graph/gradient_builder_base.cc index d57675e8b8e2..4262acb63658 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_base.cc @@ -63,16 +63,20 @@ void ComputeBroadcastBackwardAxes( auto A_dim = A_dims[i].dim_param(), B_dim = B_dims[j].dim_param(); if (A_dim != B_dim) { - LOGS_DEFAULT(INFO) << "Gradient building for node " << node_name << ": symbolic dimension expects to match. " - << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; + LOGS_DEFAULT(INFO) + << "Gradient building for node " << node_name << ": symbolic dimension expects to match. " + << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) + << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; } } else if (A_dims[i].has_dim_param() && B_dims[j].has_dim_value()) { auto A_dim = A_dims[i].dim_param(); auto B_dim = B_dims[j].dim_value(); if (B_dim != 1) { - LOGS_DEFAULT(INFO) << "Gradient building for node " << node_name << ": symbolic broadcasting expects the B_dimension to be 1. " - << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; + LOGS_DEFAULT(INFO) + << "Gradient building for node " << node_name << ": symbolic broadcasting expects the B_dimension to be 1. " + << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) + << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; } else { if (B_axes) { B_axes->push_back(gsl::narrow_cast(k)); @@ -83,8 +87,10 @@ void ComputeBroadcastBackwardAxes( auto B_dim = B_dims[j].dim_param(); if (A_dim != 1) { - LOGS_DEFAULT(INFO) << "Gradient building for node " << node_name << ": symbolic broadcasting expects the A_dimension to be 1. " - << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; + LOGS_DEFAULT(INFO) + << "Gradient building for node " << node_name << ": symbolic broadcasting expects the A_dimension to be 1. " + << "A_dims:" << ToString(A_dims) << ", B_dims:" << ToString(B_dims) + << " This is a relaxing case, and the kernel might run into problem later if A_dims and B_dims turns out not broadcastable."; } else { if (A_axes) { A_axes->push_back(gsl::narrow_cast(k)); diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 2d8a87f6d442..a4aa70c99eec 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -225,7 +225,9 @@ class GradientBuilderBase { } int OnnxOpSetVersion() const { - return graph_ != nullptr && graph_->DomainToVersionMap().find(kOnnxDomain) != graph_->DomainToVersionMap().end() ? graph_->DomainToVersionMap().at(kOnnxDomain) : -1; + return graph_ != nullptr && graph_->DomainToVersionMap().find(kOnnxDomain) != graph_->DomainToVersionMap().end() + ? graph_->DomainToVersionMap().at(kOnnxDomain) + : -1; } template diff --git a/orttraining/orttraining/core/graph/loss_function_registry.h b/orttraining/orttraining/core/graph/loss_function_registry.h index 76242276080c..6c880b6a3d9f 100644 --- a/orttraining/orttraining/core/graph/loss_function_registry.h +++ b/orttraining/orttraining/core/graph/loss_function_registry.h @@ -18,7 +18,7 @@ struct LossFunctionUsingOperator : public ILossFunction { class LossFunctionRegistry : public GenericRegistry { public: - // Register a list of non-operator loss functions stacitally. + // Register a list of non-operator loss functions statically. void RegisterNonOperatorLossFunctions(); // Register a operator loss function. diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 0d4291a3b8b3..4b6a9a6e594c 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -508,7 +508,7 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev *embedding_node); // Add flatten pattern to each input node of the subgraph - // to flattern the shape of [batch_size, seqlen, ...] to [valid_token_count, ...] + // to flatten the shape of [batch_size, seqlen, ...] to [valid_token_count, ...] InsertFlattenPatternForInput(graph, *embedding_node, 1, squeeze_out_arg, logger); handled_input_count++; for (auto& node : candidate_inputs) { diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h index cc3c90dac2d5..a09ee75c73aa 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h @@ -25,7 +25,7 @@ namespace onnxruntime { * * This transformer is implemented in the following steps: * 1. Iterate the graph and find the Embedding node that matches these requirements: - * 1.1 The 2nd input is a graph input and its rank > 2, with the first two dimensions, are: + * 1.1 Following a PythonOp(FlagAndPrintDensity) node, and its rank > 2, with the first two dimensions, are: * [batch_size, sequence_length]. Both dimensions can be symbolic or concrete dim values. * 1.2 The 3rd input(padding idx) is a scalar constant initializer, and should >= 0. * 2. Append embedding node in node_to_scan_list. @@ -54,6 +54,8 @@ namespace onnxruntime { * \ \ / / / * \_________________\_________________________/________________/______________________/ * | + * PythonOp (FlagAndPrintDensity) + * | * ATen:embedding * | * - - - - - - - - - - - -| @@ -68,7 +70,7 @@ namespace onnxruntime { * output * * - * After the transformation: + * After the transformation (PythonOp (FlagAndPrintDensity) is removed unless user need to print density for each step): * * input_ids [batch_size, seq_length] * | \ diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h b/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h index cf3470611589..2204724bacf5 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h @@ -32,10 +32,10 @@ namespace onnxruntime { * 2. Its 2nd output (log_prob) MUST NOT be a graph output and MUST NOT be consumed by other nodes. * 3. Its ignore_index exists and is a constant scalar value. * 4. Its 2nd input label's input node is not a `ShrunkGather` node (to avoid this transformer duplicated applied). - * 5. Its 2nd input label is 1) a graph input or 2) output of a Reshape node taking a graph input as its data input. + * 5. Following PythonOp (FlagAndPrintDensity). * * - * After the transformation: + * After the transformation (PythonOp (FlagAndPrintDensity) is removed unless user need to print density for each step): * labels [token_count] * \_______ * \ \ diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h index 011a007ab6e7..5196f5256fed 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include "core/optimizer/graph_transformer.h" #include "orttraining/core/optimizer/graph_transformer_config.h" diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.h b/orttraining/orttraining/core/optimizer/insert_output_rewriter.h index 5e4bf5c5ce7a..de000e00f1bf 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.h +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.h @@ -7,7 +7,7 @@ namespace onnxruntime { -// Rewrite rule that insert an addtional output to the matched node. +// Rewrite rule that insert an additional output to the matched node. class InsertMaxPoolOutput : public RewriteRule { public: InsertMaxPoolOutput() noexcept @@ -24,7 +24,7 @@ class InsertMaxPoolOutput : public RewriteRule { Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -// Rewrite rule that insert an addtional output to the matched node. +// Rewrite rule that insert an additional output to the matched node. // Adding this second output to expose FW intermediate result for speeding up BW computation class InsertSoftmaxCrossEntropyLossOutput : public RewriteRule { public: diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 088fd345135d..8d110c692751 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -60,9 +60,9 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap; * Most recent revisited for ONNX v1.15.0 release - https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/docs/Operators.md * * We defined supported list explicitly instead of using a excluding list for the following reasons: - * 1. Some ops generate indeterministic results (for example using random number generator). We need evaluate whether + * 1. Some ops generate non-deterministic results (for example using random number generator). We need evaluate whether * this is a problem for recompute before adding the support, instead of fixing this after we find and try to - * fix convergence issues (which will be very hard if we have multiple indeterministic operators by default supported.) + * fix convergence issues (which will be very hard if we have multiple non-deterministic operators by default supported.) * 2. Some ops schema will be changed in new opsets, we need also check manually whether it is applicable to recompute * or not. * 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not. diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index 5aa05b0f02e0..d87706ea9806 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -151,7 +151,7 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the * size of stashed activation. - * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a + * @param can_compromise_stashed_activation A bool return value, to indicate there are opportunities for finding a * compromised subgraph. */ std::unique_ptr CheckNodeForRecompute(const GraphViewer& graph_viewer, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index a4fbacc8a1f4..dd585068a23c 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -88,7 +88,7 @@ std::tuple IsResidualNodeArg(const GraphViewer& ----------------------| | | | | - | SimplifiedLayerNormalization (layer boudary node) + | SimplifiedLayerNormalization (layer boundary node) | | | | | MistralAttention diff --git a/orttraining/orttraining/core/optimizer/qdq_fusion.h b/orttraining/orttraining/core/optimizer/qdq_fusion.h index 722565cffa80..3bf9a7909f7e 100644 --- a/orttraining/orttraining/core/optimizer/qdq_fusion.h +++ b/orttraining/orttraining/core/optimizer/qdq_fusion.h @@ -14,7 +14,7 @@ This transformer will be used during QAT (Quantization Aware Training). For QAT an onnx graph that has Q->DQ nodes needs to be made ready for training. The output of the Q node is a quantized type. Backpropagation on quantized type is not supported in ort. So, we replace the occurrences of Q->DQ with FakeQuant which internally will perform the -Q->DQ opeeration and at the same time can support backpropagation. +Q->DQ operation and at the same time can support backpropagation. from: x (fp32) diff --git a/orttraining/orttraining/core/optimizer/transpose_replacement.h b/orttraining/orttraining/core/optimizer/transpose_replacement.h index c38e40233982..d2bbe2fdcfc1 100644 --- a/orttraining/orttraining/core/optimizer/transpose_replacement.h +++ b/orttraining/orttraining/core/optimizer/transpose_replacement.h @@ -13,10 +13,10 @@ namespace onnxruntime { Transpose is equivalent to a Reshape if: empty dimensions (which dim_value=1) can change place, not empty dimensions must be in - the same order in the permuted tenosr. + the same order in the permuted tensor. Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). -This Rewrite rule replaces Transpose which meets the requirments with Reshape. +This Rewrite rule replaces Transpose which meets the requirements with Reshape. Because Transpose need memory copy while Reshape needn't, this replacement can save overhead for memory copy. It is attempted to be triggered only on nodes with op type "Transpose". diff --git a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py index ff128c4da425..12eba90170fb 100644 --- a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py +++ b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py @@ -26,7 +26,7 @@ def override_function(m_self): # noqa: N805 from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops - warnings.warn("Apex AMP fp16_optimizer functions are overrided with faster implementation.", UserWarning) + warnings.warn("Apex AMP fp16_optimizer functions are overridden with faster implementation.", UserWarning) # Implementation adapted from https://github.com/NVIDIA/apex/blob/082f999a6e18a3d02306e27482cc7486dab71a50/apex/amp/_process_optimizer.py#L161 def post_backward_with_master_weights(self, scaler): diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index 20f4f814e547..55e2e0843213 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -140,7 +140,7 @@ def can_be_modified(self): ) def override_function(self): - warnings.warn("DeepSpeed fp16_optimizer functions are overrided with faster implementation.", UserWarning) + warnings.warn("DeepSpeed fp16_optimizer functions are overridden with faster implementation.", UserWarning) def get_grad_norm_direct(target, gradients, params, norm_type=2): from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops diff --git a/orttraining/orttraining/python/training/optim/_megatron_modifier.py b/orttraining/orttraining/python/training/optim/_megatron_modifier.py index 707727120c5c..702eba77cb74 100644 --- a/orttraining/orttraining/python/training/optim/_megatron_modifier.py +++ b/orttraining/orttraining/python/training/optim/_megatron_modifier.py @@ -27,7 +27,7 @@ def can_be_modified(self): ) def override_function(self): - warnings.warn("Megatron-LM fp16_optimizer functions are overrided with faster implementation.", UserWarning) + warnings.warn("Megatron-LM fp16_optimizer functions are overridden with faster implementation.", UserWarning) def clip_master_grads(target, max_norm, norm_type=2): """ diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 75512cb8e8c8..a8590cea2288 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -238,7 +238,7 @@ def native_group_norm_gradient(): # PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name -# are available for all versions, though they are not that convienent to use. +# are available for all versions, though they are not that convenient to use. def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] if "bicubic" in backward_fn: diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 0bd29b8d155c..10e7f60b7da0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -437,7 +437,7 @@ def permute_and_reshape_tensor( shape_tensor, ): # If matmul_output_axes and contraction_axes are contiguous in input tensor, - # we can move Reshape to before Transpose, so it's possible that the Transpoase is fused to MatMul. + # we can move Reshape to before Transpose, so it's possible that the Transpose is fused to MatMul. # Otherwise, we have to Transpose first to move those axes together and then Reshape. is_matmul_output_axes_contiguous = is_axes_contiguous(matmul_output_axes) is_contraction_axes_contiguous = is_axes_contiguous(contraction_axes) @@ -525,7 +525,7 @@ def permute_and_reshape_tensor( @register_symbolic("einsum", torch_version_end="1.13.0") @parse_args("s", "v") -def einsum_pre_troch_113(g, equation, tensor_list): +def einsum_pre_torch_113(g, equation, tensor_list): return einsum_internal(g, equation, tensor_list) @@ -540,12 +540,12 @@ def einsum_internal(g, equation, tensor_list): num_ops = len(tensors) assert num_ops > 0 - # Doesn't support implicit output is ellipsis or more than 2 oprands for now. - # Doesn't support ellipsis ('...') for now as not easy to get sizes of oprands. + # Doesn't support implicit output is ellipsis or more than 2 operands for now. + # Doesn't support ellipsis ('...') for now as not easy to get sizes of operands. if num_ops != 2 or equation.find("->") == -1 or "." in equation: return g.op("Einsum", *tensors, equation_s=equation) - # Take "ks,ksm->sm" as example. After prcoess inputs, + # Take "ks,ksm->sm" as example. After process inputs, # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m]. lhs_labels, rhs_labels, result_labels = parse_equation(equation) diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index 84d7bf641096..047cd4c59d63 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -102,7 +102,7 @@ def __init__( ): """ :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string - :param fw_feed_names: Feed names for foward pass. + :param fw_feed_names: Feed names for forward pass. :param fw_outputs_device_info: Device info for fetches in forward pass. :param bw_fetches_names: Fetch names for backward pass. :param bw_outputs_device_info: Device info for fetches in backward pass. diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 18999ce2fa1a..c1ff62a5faea 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -312,7 +312,7 @@ def _initialize_graph_builder(self, post_export_processed_model_info: PostExport def __getstate__(self): state = copy.copy(self.__dict__) - # Remove any re-contructible/pybound object from the state + # Remove any re-constructible/pybound object from the state serialization_deny_list = [ "_onnx_models", "_graph_builder", diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 80bb00e0c3ac..22627749c316 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -871,7 +871,7 @@ def _get_exported_model( enable_zero_stage3_support, stage3_param_handle, flattened_module ): required_export_kwargs = { - "input_names": model_info_for_export.onnx_graph_input_names, # did not contains paramerter as its input yet + "input_names": model_info_for_export.onnx_graph_input_names, # did not contains parameters as its input yet "output_names": output_names, "opset_version": onnx_opset_version, "do_constant_folding": False, diff --git a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py index 93d151ea1217..fcab32f4356b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py +++ b/orttraining/orttraining/python/training/ortmodule/_mem_efficient_grad_mgmt.py @@ -130,7 +130,7 @@ def _create_param_retrieval_function( Args: trainable_named_params: The trainable named parameters. - param_trigger: The trigger tensor for pulling the weights. param_trigger is pre-alloced just once + param_trigger: The trigger tensor for pulling the weights. param_trigger is pre-allocated just once before model execution, later it will be reused by each iteration. This could save the unnecessary overhead allocating for each iteration run. diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3708343a228f..d5d5ce672224 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -457,8 +457,8 @@ def _create_execution_agent(self): def __getstate__(self): state = super().__getstate__() - # Only top level classes are pickleable. So, _ORTModuleFunction is - # not pickleable. So, let's not pickle it, and redefine it when + # Only top level classes are picklable. So, _ORTModuleFunction is + # not picklable. So, let's not pickle it, and redefine it when # loading the state. del state["_forward_class"] return state diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py index e6e5ce56773e..fbd98675aebe 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py @@ -41,13 +41,13 @@ class GraphMatcher: * Second bool indicates it's producer node or consumer node for source node. * There is a list to describe the edge infos of this node to other nodes, each edge is a tuple with 3 integers, first integer is the index of the target node in the list, second integer is the output index of the edge, - and thrid integer is the input index of the edge. + and third integer is the input index of the edge. For each entry, GraphMatcher used the first edge to lookup target node, and try to use make sure the sug-graph also matches rest edge infos. Note that when lookup target node, it will only take the first matched node as target node. For example, if a source - node has multiple "MatMul" consumers nodes comsuming same output, only the first "MatMul" node will be returned. + node has multiple "MatMul" consumers nodes consuming same output, only the first "MatMul" node will be returned. You need to avoid using such confusing edge info as the first edge info for node lookup. Try to use other edge to avoid such confusion if possible. """ diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 9145fb1712e8..2d036a5abcb5 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -75,7 +75,7 @@ def __init__(self, log_level): def _extract_info(self, log_level): # get the log_level from os env variable - # OS environment variable log level superseeds the locally provided one + # OS environment variable log level supersedes the locally provided one self._validate(log_level) log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)] return log_level @@ -197,7 +197,7 @@ class _MemoryOptimizationLevel(IntFlag): USER_SPECIFIED = 0 # Fully respect user-specified config TRANSFORMER_LAYERWISE_RECOMPUTE = ( - 1 # Enable all recomputable subgraphs (excluding compromised recomptable graphs) per layer + 1 # Enable all recomputable subgraphs (excluding compromised recomputable graphs) per layer ) TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE = 2 # Enable all recomputable subgraphs per layer diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index ba6f7c2d0c03..b291bfb2ba03 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -124,7 +124,7 @@ def forward(self, *inputs, **kwargs): The first call to forward performs setup and checking steps. During this call, ORTModule determines whether the module can be trained with ONNX Runtime. For this reason, the first forward call execution takes longer than subsequent calls. - Execution is interupted if ONNX Runtime cannot process the model for training. + Execution is interrupted if ONNX Runtime cannot process the model for training. Args: inputs: positional, variable positional inputs to the PyTorch module's forward method. diff --git a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h index 542acb2e337a..2340104840a6 100644 --- a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h +++ b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h @@ -11,7 +11,7 @@ #pragma once -#include "core/common/gsl.h" +#include #include diff --git a/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc index f42fa0c4dc98..d3f2f9c7a876 100644 --- a/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/activation/activations_grad.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cpu/activation/activations_grad.h" -#include "core/common/gsl.h" +#include #if defined(_MSC_VER) #pragma warning(push) diff --git a/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc b/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc index 9aad7edf68bf..3dd52f36ec72 100644 --- a/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc +++ b/orttraining/orttraining/training_ops/cpu/loss/cross_entropy.cc @@ -7,7 +7,7 @@ #include "core/providers/common.h" #include #include "core/providers/cpu/math/matmul_helper.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc index b3c04af1a265..c74bf06a77d6 100644 --- a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc +++ b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc @@ -10,7 +10,7 @@ #include "core/providers/cpu/controlflow/scan_utils.h" #include "orttraining/training_ops/cpu/loss/cross_entropy.h" #include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h" -#include "core/common/gsl.h" +#include namespace onnxruntime { namespace contrib { diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index c3476161c1e5..f4b9c08bd90c 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cpu/op_gradients.h" -#include "core/common/gsl.h" +#include #include "core/mlas/inc/mlas.h" #include "core/providers/common.h" #include "core/providers/cpu/math/element_wise_ops.h" diff --git a/orttraining/orttraining/training_ops/cpu/tensor/split.cc b/orttraining/orttraining/training_ops/cpu/tensor/split.cc index d361f3ec64e3..1edfdac7631e 100644 --- a/orttraining/orttraining/training_ops/cpu/tensor/split.cc +++ b/orttraining/orttraining/training_ops/cpu/tensor/split.cc @@ -3,7 +3,7 @@ #include "orttraining/training_ops/cpu/tensor/split.h" -#include "core/common/gsl.h" +#include #include "core/common/narrow.h" #include "core/providers/common.h" #include "core/util/math.h" diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc index 7bd759e8976c..f3feef4391bb 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc @@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor { typedef typename ToCudaType::MappedType CudaT; const CudaT* input_data = reinterpret_cast(input_tensor.Data()); - CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT))); + CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT), + stream)); PadAndUnflattenImpl(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound, input_data, indices_tensor.Data(), reinterpret_cast(output_tensor.MutableData())); @@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { const Tensor* input_tensor = context->Input(0); const Tensor* indices_tensor = context->Input(1); const Tensor* unflatten_dims_tensor = context->Input(2); // Parse the 1-D shape tensor. + ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1, "unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions()); ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2, diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index b53fb3365912..fe47d8dbe57f 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -79,6 +79,7 @@ def parse_arguments(): "cann", "dnnl", "tensorrt", + "azure", ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") diff --git a/tools/ci_build/github/android/default_full_aar_build_settings.json b/tools/ci_build/github/android/default_full_aar_build_settings.json index 467f7048942c..b0eff7581267 100644 --- a/tools/ci_build/github/android/default_full_aar_build_settings.json +++ b/tools/ci_build/github/android/default_full_aar_build_settings.json @@ -8,6 +8,7 @@ "android_min_sdk_version": 21, "android_target_sdk_version": 24, "build_params": [ + "--enable_lto", "--android", "--parallel", "--cmake_generator=Ninja", diff --git a/tools/ci_build/github/android/training_full_aar_build_settings.json b/tools/ci_build/github/android/training_full_aar_build_settings.json index 76cb9f0b17a8..013804e2d63e 100644 --- a/tools/ci_build/github/android/training_full_aar_build_settings.json +++ b/tools/ci_build/github/android/training_full_aar_build_settings.json @@ -8,6 +8,7 @@ "android_min_sdk_version": 21, "android_target_sdk_version": 24, "build_params": [ + "--enable_lto", "--android", "--parallel", "--cmake_generator=Ninja", diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 54e83b03aa61..a3e3a202672a 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -85,6 +85,7 @@ stages: - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build \ --android_sdk_path $ANDROID_HOME \ @@ -169,6 +170,7 @@ stages: - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build_nnapi \ --android_sdk_path $ANDROID_HOME \ @@ -264,6 +266,7 @@ stages: - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build \ --android_sdk_path $ANDROID_HOME \ @@ -329,6 +332,7 @@ stages: - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build_nnapi \ --android_sdk_path $ANDROID_HOME \ @@ -401,6 +405,7 @@ stages: - script: | python3 tools/ci_build/build.py \ + --enable_lto \ --android \ --build_dir build_nnapi \ --android_sdk_path $ANDROID_HOME \ diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 45d763384ee2..43043633365b 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -32,5 +32,5 @@ jobs: parameters: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' - RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2024.0.0 -x "--use_openvino CPU --build_wheel"' + RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.0.0 -x "--use_openvino CPU --build_wheel"' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 9bd5a81181e1..cf350704f835 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.163 + version: 1.0.165 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.163 + version: 1.0.165 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index 45682c797bbb..dbd2076041b9 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -1,13 +1,13 @@ -ARG UBUNTU_VERSION=20.04 +ARG UBUNTU_VERSION=22.04 FROM ubuntu:${UBUNTU_VERSION} ARG OPENVINO_VERSION=2024.0.0 -ARG PYTHON_VERSION=3.9 +ARG PYTHON_VERSION=3.10 ADD scripts /tmp/scripts -RUN /tmp/scripts/install_ubuntu.sh -p ${PYTHON_VERSION} -d EdgeDevice && \ - /tmp/scripts/install_os_deps.sh -d EdgeDevice && \ - /tmp/scripts/install_python_deps.sh -p ${PYTHON_VERSION} -d EdgeDevice +RUN /tmp/scripts/install_ubuntu.sh -p $PYTHON_VERSION -d EdgeDevice +RUN /tmp/scripts/install_os_deps.sh -d EdgeDevice +RUN /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d EdgeDevice RUN apt update && apt install -y libnuma1 ocl-icd-libopencl1 && \ rm -rf /var/lib/apt/lists/* /tmp/scripts @@ -19,9 +19,9 @@ ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/linux/l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && \ - tar xzf l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && \ - mv l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64 openvino_2024.0.0 && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/linux/l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64.tgz && \ + tar xzf l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64.tgz && \ + mv l_openvino_toolkit_ubuntu22_2024.0.0.14509.34caeefd078_x86_64 openvino_2024.0.0 && \ cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y WORKDIR /root diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index 9c2f02dd34bf..a98096342903 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -12,6 +12,7 @@ PYTHON_VER=${PYTHON_VER:=3.8} # Some Edge devices only have limited disk space, use this option to exclude some package DEVICE_TYPE=${DEVICE_TYPE:=Normal} +# shellcheck disable=SC2034 DEBIAN_FRONTEND=noninteractive echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections @@ -19,8 +20,6 @@ apt-get update && apt-get install -y software-properties-common lsb-release OS_VERSION=$(lsb_release -r -s) -SYS_LONG_BIT=$(getconf LONG_BIT) - PACKAGE_LIST="autotools-dev \ automake \ build-essential \ @@ -37,9 +36,8 @@ PACKAGE_LIST="autotools-dev \ gfortran \ python3-dev \ language-pack-en \ - liblttng-ust0 \ + liblttng-ust-dev \ libcurl4 \ - libssl1.1 \ libkrb5-3 \ libtinfo-dev \ libtinfo5 \ @@ -50,7 +48,7 @@ PACKAGE_LIST="autotools-dev \ unzip \ zip \ rsync libunwind8 libpng-dev libexpat1-dev \ - python3-setuptools python3-numpy python3-wheel python python3-pip python3-pytest \ + python3-setuptools python3-numpy python3-wheel python3-pip python3-pytest python3-distutils \ openjdk-11-jdk \ graphviz" @@ -59,7 +57,7 @@ if [ $DEVICE_TYPE = "Normal" ]; then PACKAGE_LIST="$PACKAGE_LIST libedit-dev libxml2-dev python3-packaging" fi -PACKAGE_LIST="$PACKAGE_LIST libicu66" +PACKAGE_LIST="$PACKAGE_LIST libicu-dev" apt-get install -y --no-install-recommends $PACKAGE_LIST @@ -67,8 +65,14 @@ locale-gen en_US.UTF-8 update-locale LANG=en_US.UTF-8 if [ "$OS_VERSION" = "20.04" ]; then + # The defaul version of python is 3.8 + major=$(echo $PYTHON_VER | cut -d. -f1) + minor=$(echo $PYTHON_VER | cut -d. -f2) + if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 8 ]; then + PYTHON_VER="3.8" + fi if [ "$PYTHON_VER" != "3.8" ]; then - add-apt-repository -y ppa:deadsnakes/ppa + add-apt-repository -y ppa:deadsnakes/ppa apt-get update apt-get install -y --no-install-recommends \ python${PYTHON_VER} \ @@ -80,6 +84,23 @@ if [ "$OS_VERSION" = "20.04" ]; then #put at /usr/local/. Then there will be two pips. /usr/bin/python${PYTHON_VER} -m pip install --upgrade --force-reinstall pip==19.0.3 fi +elif [ "$OS_VERSION" = "22.04" ] ; then + # The defaul version of python is 3.10 + major=$(echo $PYTHON_VER | cut -d. -f1) + minor=$(echo $PYTHON_VER | cut -d. -f2) + if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 10 ]; then + PYTHON_VER="3.10" + fi + if [ "$PYTHON_VER" != "3.10" ]; then + add-apt-repository -y ppa:deadsnakes/ppa + apt-get update + apt-get install -y --no-install-recommends \ + python${PYTHON_VER} \ + python${PYTHON_VER}-dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VER} 1 + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2 + update-alternatives --set python3 /usr/bin/python${PYTHON_VER} + fi else exit 1 fi diff --git a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config index 215bc4015acd..3f1691f47e70 100644 --- a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config +++ b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config @@ -3,6 +3,7 @@ "os": "android", "arch": "arm64-v8a", "build_params": [ + "--enable_lto", "--android", "--android_sdk_path=/android_home", "--android_ndk_path=/ndk_home", diff --git a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config index 1348707a071c..dbebec5788dd 100644 --- a/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config +++ b/tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_with_mobile_package_ops.config @@ -3,6 +3,7 @@ "os": "android", "arch": "arm64-v8a", "build_params": [ + "--enable_lto", "--android", "--android_sdk_path=/android_home", "--android_ndk_path=/ndk_home", diff --git a/tools/ci_build/github/linux/run_dockerbuild.sh b/tools/ci_build/github/linux/run_dockerbuild.sh index 440752bc819d..9944861f519f 100755 --- a/tools/ci_build/github/linux/run_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_dockerbuild.sh @@ -20,16 +20,16 @@ ORTMODULE_BUILD=false #Training only USE_CONDA=false ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV="ALLOW_RELEASED_ONNX_OPSET_ONLY="$ALLOW_RELEASED_ONNX_OPSET_ONLY -echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as "$ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV +echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as $ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV" while getopts o:d:p:x:v:y:t:i:mue parameter_Option do case "${parameter_Option}" in -#yocto, ubuntu20.04 +#yocto, ubuntu22.04 o) BUILD_OS=${OPTARG};; #gpu, tensorrt or openvino. It is ignored when BUILD_OS is yocto. d) BUILD_DEVICE=${OPTARG};; -#python version: 3.6 3.7 (absence means default 3.6) +#python version: 3.8 3.9 3.10 3.11 3.12 (absence means default 3.8) p) PYTHON_VER=${OPTARG};; # "--build_wheel --use_openblas" x) BUILD_EXTR_PAR=${OPTARG};; @@ -48,9 +48,11 @@ m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; u) ORTMODULE_BUILD=true;; # install and use conda e) USE_CONDA=true;; +*) echo "Invalid option";; esac done +# shellcheck disable=SC2034 EXIT_CODE=1 DEFAULT_PYTHON_VER="3.8" @@ -62,7 +64,10 @@ if [[ -n "${IMAGE_CACHE_CONTAINER_REGISTRY_NAME}" ]]; then GET_DOCKER_IMAGE_CMD="${GET_DOCKER_IMAGE_CMD} --container-registry ${IMAGE_CACHE_CONTAINER_REGISTRY_NAME}" fi DOCKER_CMD="docker" - +# If BUILD_OS is ubuntu, then UBUNTU_VERSION is set to the version string after ubuntu +if [[ $BUILD_OS == ubuntu* ]]; then + UBUNTU_VERSION=${BUILD_OS#ubuntu} +fi NEED_BUILD_SHARED_LIB=true cd $SCRIPT_DIR/docker @@ -96,10 +101,9 @@ elif [ $BUILD_DEVICE = "gpu" ]; then --docker-build-args="--build-arg BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-${BUILD_OS} --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg INSTALL_DEPS_EXTRA_ARGS=\"${INSTALL_DEPS_EXTRA_ARGS}\" --build-arg USE_CONDA=${USE_CONDA} --network=host" \ --dockerfile Dockerfile.ubuntu_gpu_training --context . elif [[ $BUILD_DEVICE = "openvino"* ]]; then - BUILD_ARGS="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=3.8" + BUILD_ARGS="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg OPENVINO_VERSION=${OPENVINO_VERSION} --build-arg UBUNTU_VERSION=${UBUNTU_VERSION}" IMAGE="$BUILD_OS-openvino" DOCKER_FILE=Dockerfile.ubuntu_openvino - BUILD_ARGS+=" --build-arg OPENVINO_VERSION=${OPENVINO_VERSION}" $GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \ --docker-build-args="${BUILD_ARGS}" \ --dockerfile $DOCKER_FILE --context . diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h index eae7dc37941c..88945b75c75e 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -3,7 +3,7 @@ #include "iengine.h" #include "UniqueOrtPtr.h" -#include "core/common/gsl.h" +#include #include #include