From 56b36a58baa5738af7e1beb7dd4aa4acf7e54d9e Mon Sep 17 00:00:00 2001 From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com> Date: Sat, 29 Jun 2024 12:48:34 +0800 Subject: [PATCH] Initial PR for VSINPU execution provider (#20903) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description -It is an initial PR for VSINPU execution provider ### Motivation and Context - For support VeriSilicon hardware - TIM-VX(Tensor Interface Module) (https://github.com/VeriSilicon/TIM-VX) is an integrated software solution by Verisilicon for our hardware(A311D/i.MX 8M Plus etc.) design, it is easy to use Verisilicon’s hardware by simply connecting onnxruntime with the TIM-VX API by this VSINPU execution provider. --- cmake/CMakeLists.txt | 6 + cmake/onnxruntime.cmake | 1 + cmake/onnxruntime_providers.cmake | 32 ++ cmake/onnxruntime_unittests.cmake | 5 + include/onnxruntime/core/graph/constants.h | 1 + .../vsinpu/vsinpu_provider_factory.h | 34 ++ onnxruntime/core/framework/node_unit.cc | 9 +- onnxruntime/core/framework/utils.cc | 1 + .../core/providers/get_execution_providers.cc | 8 + .../providers/provider_factory_creators.h | 4 + .../builders/impl/activation_op_builder.h | 130 +++++ .../vsinpu/builders/impl/base_op_builder.cc | 205 +++++++ .../vsinpu/builders/impl/base_op_builder.h | 75 +++ .../vsinpu/builders/impl/cast_op_builder.h | 47 ++ .../vsinpu/builders/impl/clip_op_builder.cc | 115 ++++ .../vsinpu/builders/impl/clip_op_builder.h | 57 ++ .../vsinpu/builders/impl/concat_op_builder.h | 65 +++ .../vsinpu/builders/impl/conv_op_builder.h | 162 ++++++ .../builders/impl/dequantize_op_builder.h | 83 +++ .../builders/impl/elementwise_op_builder.h | 98 ++++ .../vsinpu/builders/impl/flatten_op_builder.h | 65 +++ .../vsinpu/builders/impl/gather_op_builder.h | 86 +++ .../vsinpu/builders/impl/gemm_op_builder.h | 148 ++++++ .../vsinpu/builders/impl/matmul_op_builder.h | 56 ++ .../vsinpu/builders/impl/norm_op_builder.h | 86 +++ .../vsinpu/builders/impl/pool_op_builder.h | 152 ++++++ .../builders/impl/qlinear_binary_op_builder.h | 85 +++ .../builders/impl/qlinearconcat_op_builder.h | 48 ++ .../builders/impl/qlinearconv_op_builder.h | 151 ++++++ .../builders/impl/qlinearmatmul_op_builder.h | 83 +++ .../builders/impl/quantize_op_builder.h | 79 +++ .../vsinpu/builders/impl/reduce_op_builder.h | 82 +++ .../vsinpu/builders/impl/resize_op_builder.h | 153 ++++++ .../vsinpu/builders/impl/softmax_op_builder.h | 101 ++++ .../vsinpu/builders/impl/squeeze_op_builder.h | 88 +++ .../vsinpu/builders/impl/tensor_op_builder.h | 142 +++++ .../vsinpu/builders/impl/tile_op_builder.h | 71 +++ .../builders/impl/unsqueeze_op_builder.h | 89 ++++ .../providers/vsinpu/builders/op_builder.h | 48 ++ .../vsinpu/builders/op_builder_factory.h | 133 +++++ .../vsinpu/patches/AccuracyCorrection.patch | 26 + .../patches/local_testing_record_res.patch | 343 ++++++++++++ .../vsinpu/patches/mlas_crosscompiling.patch | 34 ++ .../test_scripts/compare_cosine_sim.py | 29 + .../patches/test_scripts/compare_topn.py | 34 ++ .../patches/test_scripts/result_compare.sh | 23 + onnxruntime/core/providers/vsinpu/symbols.txt | 1 + .../core/providers/vsinpu/vsinpu_ep_graph.cc | 296 +++++++++++ .../core/providers/vsinpu/vsinpu_ep_graph.h | 116 ++++ .../vsinpu/vsinpu_execution_provider.cc | 277 ++++++++++ .../vsinpu/vsinpu_execution_provider.h | 53 ++ .../vsinpu/vsinpu_provider_factory.cc | 59 ++ .../vsinpu/vsinpu_provider_factory_creator.h | 34 ++ .../core/providers/vsinpu/vsinpu_util.cc | 502 ++++++++++++++++++ .../core/providers/vsinpu/vsinpu_util.h | 131 +++++ onnxruntime/test/onnx/TestCase.cc | 2 +- onnxruntime/test/onnx/main.cc | 13 +- .../test/perftest/command_args_parser.cc | 2 + onnxruntime/test/perftest/ort_test_session.cc | 6 + onnxruntime/test/providers/base_tester.cc | 4 + onnxruntime/test/providers/cpu/model_tests.cc | 15 + onnxruntime/test/util/default_providers.cc | 8 + .../test/util/include/default_providers.h | 2 + onnxruntime/test/util/include/providers.h | 3 + tools/ci_build/build.py | 2 + 65 files changed, 5096 insertions(+), 3 deletions(-) create mode 100644 include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h create mode 100644 onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch create mode 100644 onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch create mode 100644 onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch create mode 100644 onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py create mode 100644 onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py create mode 100644 onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh create mode 100644 onnxruntime/core/providers/vsinpu/symbols.txt create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_util.cc create mode 100644 onnxruntime/core/providers/vsinpu/vsinpu_util.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 4483e4d5cb17f..c4412e0934f17 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -101,6 +101,7 @@ option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to provide eigen_SOURCE_PATH if turn this on." OFF) option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) +option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -797,6 +798,11 @@ if (onnxruntime_USE_RKNPU) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_RKNPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES rknpu) endif() +if (onnxruntime_USE_VSINPU) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_VSINPU=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VSINPU=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES vsinpu) +endif() if (onnxruntime_USE_NNAPI_BUILTIN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_NNAPI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_NNAPI_BUILTIN=1) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 977aa44b0e8d7..ec98047750a91 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -189,6 +189,7 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} + ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 7e7819ac31a19..402135adbdd89 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -80,6 +80,9 @@ endif() if(onnxruntime_USE_RKNPU) set(PROVIDERS_RKNPU onnxruntime_providers_rknpu) endif() +if(onnxruntime_USE_VSINPU) + set(PROVIDERS_VSINPU onnxruntime_providers_vsinpu) +endif() if(onnxruntime_USE_DML) set(PROVIDERS_DML onnxruntime_providers_dml) endif() @@ -188,6 +191,35 @@ if (onnxruntime_USE_TVM) include(onnxruntime_providers_tvm.cmake) endif() +if (onnxruntime_USE_VSINPU) + add_definitions(-DUSE_VSINPU=1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") + file(GLOB_RECURSE onnxruntime_providers_vsinpu_srcs + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/builders/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vsinpu/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vsinpu_srcs}) + add_library(onnxruntime_providers_vsinpu ${onnxruntime_providers_vsinpu_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vsinpu + onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11 + safeint_interface nsync::nsync_cpp) + add_dependencies(onnxruntime_providers_vsinpu ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_vsinpu PROPERTIES FOLDER "ONNXRuntime" LINKER_LANGUAGE CXX) + target_include_directories(onnxruntime_providers_vsinpu PRIVATE ${ONNXRUNTIME_ROOT} $ENV{TIM_VX_INSTALL}/include) + + find_library(TIMVX_LIBRARY NAMES tim-vx PATHS $ENV{TIM_VX_INSTALL}/lib NO_DEFAULT_PATH) + if(TIMVX_LIBRARY) + target_link_libraries(onnxruntime_providers_vsinpu PRIVATE ${TIMVX_LIBRARY}) + else() + message(FATAL_ERROR "Cannot find TIM-VX library!") + endif() + +endif() + if (onnxruntime_USE_XNNPACK) include(onnxruntime_providers_xnnpack.cmake) endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ed71e7a57a500..711a9f77f9094 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -546,6 +546,10 @@ if(onnxruntime_USE_NNAPI_BUILTIN) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_VSINPU) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_vsinpu) +endif() + if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) endif() @@ -589,6 +593,7 @@ set(ONNXRUNTIME_TEST_LIBS ${onnxruntime_libs} # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} + ${PROVIDERS_VSINPU} ${PROVIDERS_JS} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index c4a46cd422219..39acb6b4f2aa4 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -52,6 +52,7 @@ constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; +constexpr const char* kVSINPUExecutionProvider = "VSINPUExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; diff --git a/include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h b/include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h new file mode 100644 index 0000000000000..a84067a19aa8a --- /dev/null +++ b/include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h @@ -0,0 +1,34 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include "onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_VSINPU, _In_ OrtSessionOptions* options); + +#ifdef __cplusplus +} +#endif diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index ac4301641105a..e2c06fbdfa621 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -285,7 +285,7 @@ void NodeUnit::InitForSingleNode() { const auto& output_defs = target_node_.OutputDefs(); const auto& node_attrs = target_node_.GetAttributes(); auto qlinear_type = GetQLinearOpType(target_node_); - if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + if (qlinear_type == QLinearOpType::Unknown) { // Not a Qlinear op, add all inputs / outputs auto add_all_io = [](std::vector& defs, const ConstPointerContainer>& node_defs) { @@ -351,6 +351,13 @@ void NodeUnit::InitForSingleNode() { NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3 ? input_defs[2] : nullptr, axis}}); + } else if (IsVariadicQLinearOp(qlinear_type)) { + size_t input_num = (input_defs.size() - 2) / 3; + for (size_t i = 0; i < input_num; i++) { + inputs_.push_back(NodeUnitIODef{*input_defs[3 * i + 2], NodeUnitIODef::QuantParam{*input_defs[3 * i + 3], + input_defs[3 * i + 4]}}); + } + outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[0], input_defs[1]}}); } else { ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9c282210d2169..9eed0249711f9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -61,6 +61,7 @@ bool ProviderIsCpuBased(const std::string& provider_type) { provider_type == onnxruntime::kVitisAIExecutionProvider || provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || + provider_type == onnxruntime::kVSINPUExecutionProvider || provider_type == onnxruntime::kAclExecutionProvider || provider_type == onnxruntime::kArmNNExecutionProvider || provider_type == onnxruntime::kRknpuExecutionProvider || diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index b0f510f054a03..61c035bc29ed5 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -98,6 +98,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kVSINPUExecutionProvider, +#ifdef USE_VSINPU + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 42a58097e1635..47d3f2f793d7c 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -46,6 +46,10 @@ #include "core/providers/nnapi/nnapi_provider_factory_creator.h" #endif +#if defined(USE_VSINPU) +#include "core/providers/vsinpu/vsinpu_provider_factory_creator.h" +#endif + #if defined(USE_JSEP) #include "core/providers/js/js_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h new file mode 100644 index 0000000000000..9a59d90365f64 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/activation_op_builder.h @@ -0,0 +1,130 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ReluOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Relu Activation."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +class SigmoidOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Sigmoid Activation."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +class TanhOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Tanh activation."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class LeakyReluOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating LeakyRelu activation."; + const auto& node = node_unit.GetNode(); + NodeAttrHelper helper(node); + auto alpha = helper.Get("alpha", 1.0f); + auto op = + graph_ep->GetGraph()->CreateOperation(alpha); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class EluOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Elu activation."; + const auto& node = node_unit.GetNode(); + NodeAttrHelper helper(node); + auto alpha = helper.Get("alpha", 1.0f); + auto op = + graph_ep->GetGraph()->CreateOperation(alpha); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class HardSigmoidOpBuilder : public BaseOpBuilder { + public: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating HardSigmoid activation."; + const auto& node = node_unit.GetNode(); + NodeAttrHelper helper(node); + auto alpha = helper.Get("alpha", 1.0f); + auto beta = helper.Get("beta", 1.0f); + auto op = graph_ep->GetGraph()->CreateOperation( + alpha, beta); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc new file mode 100644 index 0000000000000..894bf8e4444f8 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc @@ -0,0 +1,205 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +bool BaseOpBuilder::IsSupported(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) const { + auto initializers = graph_viewer.GetAllInitializedTensors(); + if (!HasSupportedOpSet(node_unit)) { + return false; + } + if (!HasSupportedInputOutputs(initializers, node_unit)) { + return false; + } + return IsOpSupported(graph_viewer, &node_unit.GetNode()); +} + +bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const { + // We do not support unknown(null) input shape + auto has_supported_shape = [](const NodeArg& node_arg, const std::string& name, const std::string& op_type) { + const auto* shape_proto = node_arg.Shape(); + if (!shape_proto) { + LOGS_DEFAULT(WARNING) << "Node [" << name << "] type [" << op_type + << "] Input [" << node_arg.Name() << "] has no shape"; + return false; + } + + // We do not support dynamic shape input yet, but resize op's second input can be empty + for (const auto& dim : shape_proto->dim()) { + if (!dim.has_dim_value()) { + LOGS_DEFAULT(WARNING) << "Dynamic shape is not supported for now, for input:" << node_arg.Name(); + return false; + } + if (dim.dim_value() == 0 && op_type != "Resize") { + LOGS_DEFAULT(WARNING) << "Zero in shape is not supported for now, for input:" << node_arg.Name(); + return false; + } + } + return true; + }; + + auto has_initialized_quant_param = [](const NodeArg& arg, const InitializedTensorSet& initializers) { + auto it = initializers.find(arg.Name()); + if (it == initializers.end()) { + LOGS_DEFAULT(WARNING) << "The quantization param must be an initializer tensor"; + return false; + } + return true; + }; + + for (const auto& input : node_unit.Inputs()) { + if (!input.node_arg.Exists()) { + continue; + } + if (!has_supported_shape(input.node_arg, node_unit.Name(), node_unit.OpType())) + return false; + + if (input.quant_param.has_value()) { + if (!has_supported_shape(input.quant_param->scale, node_unit.Name(), node_unit.OpType())) + return false; + + if (!has_initialized_quant_param(input.quant_param->scale, initializers)) + return false; + // zero point is optional + if (input.quant_param->zero_point) { + if (!has_supported_shape(*input.quant_param->zero_point, node_unit.Name(), node_unit.OpType())) + return false; + if (!has_initialized_quant_param(*input.quant_param->zero_point, initializers)) + return false; + if (input.quant_param->zero_point->Type() != input.node_arg.Type()) { + LOGS_DEFAULT(ERROR) << "Invalid input type because the data type mismatch with its' quant param type."; + return false; + } + } + } + } + for (const auto& output : node_unit.Outputs()) { + if (output.quant_param.has_value()) { + if (!has_supported_shape(output.quant_param->scale, node_unit.Name(), node_unit.OpType())) + return false; + + if (!has_initialized_quant_param(output.quant_param->scale, initializers)) + return false; + // zero point is optional + if (output.quant_param->zero_point) { + if (!has_supported_shape(*output.quant_param->zero_point, node_unit.Name(), node_unit.OpType())) + return false; + if (!has_initialized_quant_param(*output.quant_param->zero_point, initializers)) + return false; + } + } + } + return HasSupportedInputOutputsImpl(initializers, node_unit); +} + +bool BaseOpBuilder::HasSupportedInputOutputsImpl( + const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit) const { + // Check input/output data type, int64 is generally unsupported + // specific op builder can override this if the int64 input corresponds to VSINPU param + for (const auto& input : node_unit.Inputs()) { + auto input_type = input.node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&input.node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + } + for (const auto& output : node_unit.Outputs()) { + auto output_type = output.node_arg.Type(); + if (*output_type == "tensor(int64)" || !util::IsTypeSupported(&output.node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported output type : " + << *output_type; + return false; + } + } + return true; +} + +bool BaseOpBuilder::HasSupportedOpSet(const NodeUnit& node_unit) const { + auto since_version = node_unit.SinceVersion(); + if (since_version < GetMinSupportedOpSet(node_unit) || since_version > GetMaxSupportedOpSet(node_unit)) { + LOGS_DEFAULT(VERBOSE) << node_unit.OpType() << " opset [" << since_version + << "] is only supported for opset [" + << GetMinSupportedOpSet(node_unit) << ", " + << GetMaxSupportedOpSet(node_unit) << "]"; + return false; + } + + return true; +} + +bool BaseOpBuilder::BuildOp(vsi::npu::GraphEP* graph_ep, + const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) { + std::vector> inputs; + std::vector input_defs = node_unit.Inputs(); + std::vector output_defs = node_unit.Outputs(); + + for (const auto input_def : input_defs) { + auto it = std::find_if( + graph_ep->GetGraphInputs().begin(), graph_ep->GetGraphInputs().end(), + [input_def](const std::shared_ptr& info) { + return info->name == input_def.node_arg.Name(); + }); + tim::vx::TensorAttribute attr; + if (graph_viewer.IsConstantInitializer(input_def.node_arg.Name(), true)) { + attr = tim::vx::TensorAttribute::CONSTANT; + } else if (it == graph_ep->GetGraphInputs().end()) { + attr = tim::vx::TensorAttribute::TRANSIENT; + } else { + attr = tim::vx::TensorAttribute::INPUT; + } + + auto tensor = graph_ep->MapTIMVXTensor(graph_ep->GetGraph(), input_def, node_unit, + &graph_viewer, attr); + inputs.push_back(tensor); + } + + std::vector> outputs; + + for (auto output_def : output_defs) { + auto it = std::find_if( + graph_ep->GetGraphOutputs().begin(), graph_ep->GetGraphOutputs().end(), + [output_def](const std::shared_ptr& info) { + return info->name == output_def.node_arg.Name(); + }); + tim::vx::TensorAttribute attribute = + it == graph_ep->GetGraphOutputs().end() + ? tim::vx::TensorAttribute::TRANSIENT + : tim::vx::TensorAttribute::OUTPUT; + auto tensor = graph_ep->MapTIMVXTensor(graph_ep->GetGraph(), output_def, node_unit, + &graph_viewer, attribute); + outputs.push_back(tensor); + } + return HandleBuildOp(graph_ep, inputs, outputs, node_unit); +} +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h new file mode 100644 index 0000000000000..c0cf3365f46e3 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h @@ -0,0 +1,75 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include "core/providers/vsinpu/builders/op_builder.h" +#include "core/providers/vsinpu/vsinpu_ep_graph.h" +#include "core/providers/vsinpu/vsinpu_util.h" +#include "tim/vx/operation.h" +#include "tim/vx/ops.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class BaseOpBuilder : public IOpBuilder { + public: + virtual ~BaseOpBuilder() = default; + + bool IsSupported(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) const override; + bool BuildOp(vsi::npu::GraphEP* graph_ep, + const onnxruntime::GraphViewer& graph_viewer, const NodeUnit& node_unit); + virtual bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const { + return true; + } + + virtual bool IsQuantizedOp(const NodeUnit& /* node_unit */) const { return false; } + + virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } + virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 22; } + + virtual bool HasSupportedInputOutputsImpl( + const InitializedTensorSet& initializers, const NodeUnit& node_unit) const; + + // TODO(cfy): Check if this node_unit's type is supported + virtual bool IsNodeUnitTypeSupported(const NodeUnit& node_unit) const { return true; } + + virtual bool HandleBuildOp( + vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) { + return true; + } + + private: + bool HasSupportedOpSet(const NodeUnit& node_unit) const; + bool HasSupportedInputOutputs(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const; +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h new file mode 100644 index 0000000000000..6579f0ca9045f --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h @@ -0,0 +1,47 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class CastOpBuilder : public BaseOpBuilder { + protected: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, std::vector>& inputs, + std::vector>& outputs, const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Cast Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc new file mode 100644 index 0000000000000..85096d0e262d7 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -0,0 +1,115 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include "core/providers/vsinpu/builders/impl/clip_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +namespace clip_internal { +template +struct LowMax { + constexpr static T low() { + return std::numeric_limits::lowest(); + } + constexpr static T max() { + return std::numeric_limits::max(); + } +}; +} // namespace clip_internal + +template +struct ClipOpBuilder::ClipImpl { + ClipImpl(vsi::npu::GraphEP* graph_ep, std::vector>& inputs, + std::vector>& outputs) { + T min_default = clip_internal::LowMax::low(); + T max_default = clip_internal::LowMax::max(); + + T* min_data = &min_default; + T* max_data = &max_default; + std::shared_ptr min_tensor = nullptr; + std::shared_ptr max_tensor = nullptr; + if (inputs.size() > 1) { + min_tensor = inputs[1]; + if (inputs.size() > 2) { + max_tensor = inputs[2]; + } + } + if (min_tensor) { + min_tensor->CopyDataFromTensor(min_data); + } + if (max_tensor) { + max_tensor->CopyDataFromTensor(max_data); + } + auto op = graph_ep->GetGraph()->CreateOperation( + static_cast(*min_data), static_cast(*max_data)); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } +}; + +bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) { + LOGS_DEFAULT(INFO) << "Creating Clip Op."; + if (node_unit.SinceVersion() <= 6) { + NodeAttrHelper helper(node_unit.GetNode()); + auto min = helper.Get("min", -3.402e+38f); + auto max = helper.Get("max", 3.402e+38f); + auto op = graph_ep->GetGraph()->CreateOperation(min, max); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } else { + switch (inputs[0]->GetDataType()) { + case tim::vx::DataType::INT8: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::UINT8: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::INT16: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::INT32: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::FLOAT16: + ClipImpl(graph_ep, inputs, outputs); + break; + case tim::vx::DataType::FLOAT32: + default: + ClipImpl(graph_ep, inputs, outputs); + break; + } + } + return true; +} + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h new file mode 100644 index 0000000000000..368cb092657c8 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.h @@ -0,0 +1,57 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ClipOpBuilder final : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + if (node->SinceVersion() > 6) { + if (node->InputDefs().size() > 1 && + !Contains(graph_viewer.GetAllInitializedTensors(), node->InputDefs()[1]->Name())) { + LOGS_DEFAULT(WARNING) << "Min/Max value must be const input or attribute."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override; + + private: + template + struct ClipImpl; +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h new file mode 100644 index 0000000000000..4d3fc658b7bef --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/concat_op_builder.h @@ -0,0 +1,65 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ConcatOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", 0); + auto input_defs = node->InputDefs(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(ERROR) << "Axis is invalid in Concat."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Concat Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(static_cast(axis), inputs.size()); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h new file mode 100644 index 0000000000000..d44e1ce1799c1 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h @@ -0,0 +1,162 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class ConvOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + if (shape.NumDimensions() == 5) { + LOGS_DEFAULT(WARNING) << "Not support conv3d yet."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + auto input_tensor = inputs[0]; + auto weight_tensor = inputs[1]; + auto OChannel_idx = weight_tensor->GetShape().size() - 1; + const bool is_1d_conv = + weight_tensor->GetShape().size() == 3 ? true : false; + NodeAttrHelper helper(node_unit.GetNode()); + auto padtype = helper.Get("auto_pad", std::string("")); + auto group = helper.Get("group", static_cast(1)); + + std::string op_type = (group != 1 && group == weight_tensor->GetShape()[OChannel_idx]) + ? "DepthwiseConv" + : (group != 1) ? "GroupConv" + : "Conv"; + op_type += is_1d_conv ? "1D" : "2D"; + std::string op_name = std::string("Creating ") + op_type + " Op"; + LOGS_DEFAULT(INFO) << op_name; + + uint32_t default_uint = 1; + std::vector default_vec = {1, 1}; + + auto stride = + helper.Get("strides", is_1d_conv ? std::vector{default_uint} + : default_vec); + auto dilation = + helper.Get("dilations", is_1d_conv ? std::vector{default_uint} + : default_vec); + + std::shared_ptr op; + if (padtype != "NOTSET") { // array "pads" is not set + if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { + if (is_1d_conv) { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), stride[0], + dilation[0], group, tim::vx::DataLayout::WCN, + tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), + /* W_stride, H_stride*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } else { + int32_t multiplier = group == 1 + ? 0 + : weight_tensor->GetShape()[OChannel_idx] / input_tensor->GetShape()[OChannel_idx - 1]; + if (is_1d_conv) { + op = graph_ep->GetGraph()->CreateOperation( + vsi::npu::util::GetPadType(padtype), stride[0], dilation[0], multiplier, + tim::vx::DataLayout::WCN, tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph()->CreateOperation( + vsi::npu::util::GetPadType(padtype), + /* W_stride, H_stride*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } + } else { + auto pads = helper.Get("pads", std::vector{0U, 0U}); + if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { + if (is_1d_conv) { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), + std::array{pads[0], pads[1]}, stride[0], + dilation[0], group, tim::vx::DataLayout::WCN, + tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph() + ->CreateOperation( + /* W_begin,W_end, H_begin,H_end*/ std::array< + uint32_t, 4>{pads[1], pads[3], pads[0], pads[2]}, + /* W_stride, H_stide*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } else { + int32_t multiplier = group == 1 + ? 0 + : weight_tensor->GetShape()[OChannel_idx] / input_tensor->GetShape()[OChannel_idx - 1]; + if (is_1d_conv) { + op = graph_ep->GetGraph()->CreateOperation( + std::array{pads[0], pads[1]}, stride[0], dilation[0], + multiplier, tim::vx::DataLayout::WCN, tim::vx::DataLayout::WIcOc); + } else { + op = graph_ep->GetGraph()->CreateOperation( + /* W_begin,W_end, H_begin,H_end*/ std::array{pads[1], pads[3], + pads[0], pads[2]}, + /* W_stride, H_stride*/ + std::array{stride[1], stride[0]}, + /* W_dilation, H_dilation*/ + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } + } + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h new file mode 100644 index 0000000000000..50b295f2fb539 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/dequantize_op_builder.h @@ -0,0 +1,83 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class DequantizeLinearOpBuilder : public BaseOpBuilder { + enum DequantizeINPUTS { + input_tensor = 0, + scale_tensor = 1, + zero_point_tensor = 2 + }; + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (!node_unit.Inputs()[0].quant_param.has_value()) { + LOGS_DEFAULT(WARNING) << "The quantization params must be known."; + return false; + } + if (node_unit.Inputs()[0].quant_param->scale.Shape()->dim_size() != 0 && + node_unit.Inputs()[0].quant_param->scale.Shape()->dim(0).dim_value() != 1) { + LOGS_DEFAULT(WARNING) << "Per channel quantized input is not support in DequantizeLinear op."; + return false; + } + return true; + } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + if (helper.HasAttr("block_size") && helper.Get("block_size", 0) != 0) { + LOGS_DEFAULT(WARNING) << "Not support block quantization yet."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Dequantize Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h new file mode 100644 index 0000000000000..89809a4513340 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h @@ -0,0 +1,98 @@ + +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +#define ELEMENTWISE_OP_BUILDER(onnx_op_type, vsinpu_op_kind) \ + class onnx_op_type##OpBuilder : public BaseOpBuilder { \ + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, \ + const Node* node) const override { \ + for (auto input : node->InputDefs()) { \ + if (*input->Type() == "tensor(int64)") { \ + LOGS_DEFAULT(WARNING) << "Int64 type is not suppoted as elementwise operation input."; \ + return false; \ + } \ + } \ + return true; \ + } \ + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, \ + std::vector>& inputs, \ + std::vector>& outputs, \ + const NodeUnit& node_unit) override { \ + LOGS_DEFAULT(INFO) << "Creating " << #onnx_op_type << " Op"; \ + auto op = graph_ep->GetGraph() -> CreateOperation(); \ + (*op).BindInputs(inputs).BindOutputs(outputs); \ + return true; \ + ; \ + } \ + }; + +ELEMENTWISE_OP_BUILDER(Add, Add); +ELEMENTWISE_OP_BUILDER(Sub, Sub); +ELEMENTWISE_OP_BUILDER(Mul, Multiply); +ELEMENTWISE_OP_BUILDER(Div, Div); // not consider zero +ELEMENTWISE_OP_BUILDER(Abs, Abs); +ELEMENTWISE_OP_BUILDER(Sqrt, Sqrt); +ELEMENTWISE_OP_BUILDER(Exp, Exp); +ELEMENTWISE_OP_BUILDER(Floor, Floor); +ELEMENTWISE_OP_BUILDER(Log, Log); +ELEMENTWISE_OP_BUILDER(Sin, Sin); +ELEMENTWISE_OP_BUILDER(HardSwish, HardSwish); + +class PowOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input0_type = *node->InputDefs()[0]->Type(); + auto input1_type = *node->InputDefs()[1]->Type(); + if (input0_type != input1_type) { + if ((input0_type == "tensor(float)" && input1_type == "tensor(int32)") || + (input0_type == "tensor(int32)" && input1_type == "tensor(float)")) { + LOGS_DEFAULT(WARNING) << "Pow op does not support one of input is float32 while the other one is int32 type."; + return false; + } + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Pow Op"; + auto op = graph_ep->GetGraph() + ->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h new file mode 100644 index 0000000000000..dfb0bb9c1b99f --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/flatten_op_builder.h @@ -0,0 +1,65 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class FlattenOpBuilder : public BaseOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Flatten Op."; + std::vector reshape_param; + if (outputs[0]->GetShape().size() == 2) { + reshape_param = outputs[0]->GetShape(); + } else { + auto input_shape = inputs[0]->GetShape(); + NodeAttrHelper helper(node_unit.GetNode()); + int64_t axis = helper.Get("axis", 1); + axis = util::ReverseAxis(static_cast(axis), input_shape.size()); + uint32_t first_dim = 1; + for (int64_t i = 0; i < axis; i++) { + first_dim *= inputs[0]->GetShape()[i]; + } + uint32_t second_dim = inputs[0]->GetSpec().GetElementNum() / first_dim; + reshape_param.push_back(first_dim); + reshape_param.push_back(second_dim); + } + auto op = graph_ep->GetGraph()->CreateOperation(reshape_param); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h new file mode 100644 index 0000000000000..0325b68ae0ad7 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h @@ -0,0 +1,86 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class GatherOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input = node_unit.Inputs()[0]; + auto indices = node_unit.Inputs()[1]; + if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&indices.node_arg)) { + if (*input.node_arg.Type() == "tensor(int64)") { + LOGS_DEFAULT(WARNING) << "Only support indices tensor to be int64 type in gather op."; + return false; + } + if (*indices.node_arg.Type() != "tensor(int64)" && *indices.node_arg.Type() != "tensor(int32)") { + LOGS_DEFAULT(WARNING) << "Unsupported indices tensor type in gather op."; + return false; + } + if (*indices.node_arg.Type() == "tensor(int64)" && !Contains(initializers, indices.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Only support const attribute if indice tensor is in int64 type."; + return false; + } + return true; + } + return false; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Gather Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(axis, 0); + + bool is_i64_indices = inputs[1]->GetDataType() == tim::vx::DataType::INT64; + if (!is_i64_indices) { + (*op).BindInputs(inputs).BindOutputs(outputs); + } else { + std::vector origin_data(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(origin_data.data()); + std::vector transformed_data(origin_data.begin(), origin_data.end()); + tim::vx::TensorSpec ts = inputs[1]->GetSpec().SetAttribute(tim::vx::TensorAttribute::INPUT); + ts.SetDataType(tim::vx::DataType::INT32); + auto transformed_indices = graph_ep->GetGraph()->CreateTensor(ts, transformed_data.data()); + (*op).BindInput(inputs[0]).BindInput(transformed_indices).BindOutput(outputs[0]); + } + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h new file mode 100644 index 0000000000000..6f2c590b862b6 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gemm_op_builder.h @@ -0,0 +1,148 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class GemmOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + NodeAttrHelper helper(*node); + auto weight_units = helper.Get("transB", 0) == 1 + ? vsi::npu::util::GetTensorShape(*input_defs[1]).GetDims()[0] + : vsi::npu::util::GetTensorShape(*input_defs[1]).GetDims()[1]; + if (input_defs.size() > 2) { + auto bias_shape = vsi::npu::util::GetTensorShape(*input_defs[2]); + if (bias_shape.NumDimensions() == 1 && bias_shape.GetDims()[0] != weight_units) { + LOGS_DEFAULT(WARNING) << "Not support to broadcast bias shape."; + return false; + } else if (bias_shape.NumDimensions() == 2 && + (bias_shape.Size() != weight_units || + (bias_shape.GetDims()[0] != 1 && bias_shape.GetDims()[1] != 1))) { + LOGS_DEFAULT(WARNING) << "Not support 2-dims bias shape."; + return false; + } + + if (*input_defs[2]->Type() == "tensor(float16)" && + !graph_viewer.IsConstantInitializer(input_defs[2]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Not support f16 bias with input attr."; + return false; + } + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Gemm Op."; + auto input_A = inputs[0]; + auto input_B = inputs[1]; + NodeAttrHelper helper(node_unit.GetNode()); + + auto trans_A = helper.Get("transA", 0); + auto trans_B = helper.Get("transB", 0); + const bool has_alpha = (helper.Get("alpha", 1.0f) != 1.0); + const bool has_beta = (helper.Get("beta", 1.0f) != 1.0); + const bool has_C = (inputs.size() == 3); + auto weight_units = helper.Get("transB", 0) == 1 ? inputs[1]->GetShape()[1] : inputs[1]->GetShape()[0]; + + tim::vx::TensorSpec coef_spec(tim::vx::DataType::FLOAT32, {1}, + tim::vx::TensorAttribute::CONSTANT); + + auto multiply_impl = [&](std::shared_ptr input, + std::shared_ptr coef, + std::shared_ptr output) { + auto multiply_op = graph_ep->GetGraph()->CreateOperation(); + (*multiply_op).BindInput(input).BindInput(coef).BindOutput(output); + graph_ep->GetOps().push_back(multiply_op); + }; + + auto transpose_impl = [&](std::shared_ptr input, + std::shared_ptr output) { + std::vector perm = {1U, 0U}; + auto transpose_op = graph_ep->GetGraph()->CreateOperation(perm); + (*transpose_op).BindInput(input).BindOutput(output); + graph_ep->GetOps().push_back(std::move(transpose_op)); + }; + + auto fc_impl = [&](std::vector> inputs, + std::shared_ptr output) { + auto fc_op = graph_ep->GetGraph()->CreateOperation(0, weight_units); + (*fc_op).BindInputs(inputs).BindOutput(output); + graph_ep->GetOps().push_back(std::move(fc_op)); + }; + + auto alpha_A = input_A; + std::shared_ptr beta_C; + auto final_A = input_A; + auto final_B = input_B; + + if (has_alpha) { + auto alpha_tensor = graph_ep->GetGraph()->CreateTensor(coef_spec); + auto alpha = helper.Get("alpha", 1.0f); + alpha_tensor->CopyDataToTensor(&alpha); + alpha_A = graph_ep->GetGraph()->CreateTensor( + input_A->GetSpec().AsTransientSpec()); + multiply_impl(input_A, alpha_tensor, alpha_A); + final_A = alpha_A; + } + if (has_beta) { + auto beta_tensor = graph_ep->GetGraph()->CreateTensor(coef_spec); + auto beta = helper.Get("beta", 1.0f); + beta_tensor->CopyDataToTensor(&beta); + beta_C = graph_ep->GetGraph()->CreateTensor( + inputs[2]->GetSpec().AsTransientSpec()); + multiply_impl(inputs[2], beta_tensor, beta_C); + } else if (has_C) { + beta_C = inputs[2]; + } + + if (trans_A) { + final_A = graph_ep->GetGraph()->CreateTensor( + input_A->GetSpec().AsTransientSpec()); + transpose_impl(alpha_A, final_A); + } + if (!trans_B) { + final_B = graph_ep->GetGraph()->CreateTensor( + input_B->GetSpec().AsTransientSpec()); + transpose_impl(input_B, final_B); + } + std::vector> fc_inputs = {final_A, final_B}; + + if (has_C) fc_inputs.push_back(beta_C); + fc_impl(fc_inputs, outputs[0]); + + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h new file mode 100644 index 0000000000000..8cdf72906b644 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/matmul_op_builder.h @@ -0,0 +1,56 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class MatMulOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto output_defs = node->OutputDefs(); + if (output_defs[0]->Shape()->dim_size() == 0) { + LOGS_DEFAULT(WARNING) << "Inner product of 1-D tensor is not supported in MatMul op."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Matmul Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h new file mode 100644 index 0000000000000..997163c6b1a6d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/norm_op_builder.h @@ -0,0 +1,86 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class BatchNormOpBuilder : public BaseOpBuilder { + enum NormINPUTS { + input_tensor = 0, + scale_tensor = 1, + Bias_tensor = 2, + mean_tensor = 3, + var_tensor = 4 + }; + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 9; } + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + NodeAttrHelper helper(*node); + auto training_mode = helper.Get("training_mode", 0); + if (training_mode) { + LOGS_DEFAULT(WARNING) << "Training is not supported in batch_norm op."; + return false; + } + if (helper.HasAttr("spatial")) { + LOGS_DEFAULT(WARNING) << "VSINPU does not support 'spatial' parameter."; + return false; + } + if (!graph_viewer.IsConstantInitializer(input_defs[NormINPUTS::scale_tensor]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Not support mean/var/gamma/beta set as dynamic input yet."; + return false; + } + + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating BatchNorm Op."; + NodeAttrHelper helper(node_unit.GetNode()); + auto epsilon = helper.Get("epsilon", 1e-5f); + auto op = graph_ep->GetGraph()->CreateOperation(epsilon); + std::vector> reordered_inputs; + int indices[] = {NormINPUTS::input_tensor, NormINPUTS::mean_tensor, NormINPUTS::var_tensor, + NormINPUTS::scale_tensor, NormINPUTS::Bias_tensor}; + for (int i : indices) { + reordered_inputs.push_back(inputs[i]); + } + (*op).BindInputs(reordered_inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h new file mode 100644 index 0000000000000..7cfa9faf68480 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pool_op_builder.h @@ -0,0 +1,152 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class BasePoolOpBuilder : public BaseOpBuilder { + public: + explicit BasePoolOpBuilder(tim::vx::PoolType pool_type) : pool_type_(pool_type) {} + + protected: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const override { + auto shape = vsi::npu::util::GetTensorShape(*node->InputDefs()[0]); + if (shape.NumDimensions() == 5) { + LOGS_DEFAULT(WARNING) << "3DPool is not supported yet."; + return false; + } + + NodeAttrHelper helper(*node); + if (helper.HasAttr("dilations")) { + LOGS_DEFAULT(WARNING) << "NonMaxPool with Dilation parameter is not supported."; + return false; + } + return true; + } + bool CreatePoolingOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const std::array& kernel_size, + const std::array& strides, + const std::array& pads, + bool is_global, + const tim::vx::RoundType ceil_mode) { + const bool is_1d_pool = inputs[0]->GetShape().size() == 3; + std::shared_ptr op; + + // Create the appropriate pooling operation + if (is_global) { + if (is_1d_pool) { + op = graph_ep->GetGraph()->CreateOperation(pool_type_, inputs[0]->GetShape()[0], + ceil_mode); + } else { + std::array input_size = {inputs[0]->GetShape()[0], inputs[0]->GetShape()[1]}; + op = graph_ep->GetGraph()->CreateOperation(pool_type_, input_size, ceil_mode); + } + + } else { + if (is_1d_pool) { + std::array arr = {pads[2], pads[0]}; + op = graph_ep->GetGraph()->CreateOperation(pool_type_, arr, + kernel_size[1], strides[1], ceil_mode); + } else { + op = graph_ep->GetGraph()->CreateOperation(pool_type_, pads, kernel_size, + strides, ceil_mode); + } + } + + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } + tim::vx::PoolType pool_type_; +}; + +class TraditionalPoolOpBuilder : public BasePoolOpBuilder { + public: + TraditionalPoolOpBuilder() : BasePoolOpBuilder(tim::vx::PoolType::MAX) {} + + protected: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + NodeAttrHelper helper(node_unit.GetNode()); + auto ksize = helper.Get("kernel_shape", std::vector{1U, 1U}); + auto strides = helper.Get("strides", std::vector{1U, 1U}); + auto pads = helper.Get("pads", std::vector{0U, 0U, 0U, 0U}); + tim::vx::RoundType ceil_mode = helper.Get("ceil_mode", 0U) == 0 + ? tim::vx::RoundType::FLOOR + : tim::vx::RoundType::CEILING; + return CreatePoolingOp(graph_ep, inputs, outputs, {ksize[1], ksize[0]}, {strides[1], strides[0]}, + {pads[1], pads[3], pads[0], pads[2]}, false, ceil_mode); + } +}; + +class GlobalPoolOpBuilder : public BasePoolOpBuilder { + public: + GlobalPoolOpBuilder() : BasePoolOpBuilder(tim::vx::PoolType::MAX) {} + + protected: + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + NodeAttrHelper helper(node_unit.GetNode()); + tim::vx::RoundType ceil_mode = helper.Get("ceil_mode", 0U) == 0 + ? tim::vx::RoundType::FLOOR + : tim::vx::RoundType::CEILING; + return CreatePoolingOp(graph_ep, inputs, outputs, {}, {}, {}, true, ceil_mode); + } +}; + +class GlobalAveragePoolOpBuilder : public GlobalPoolOpBuilder { + public: + GlobalAveragePoolOpBuilder() { pool_type_ = tim::vx::PoolType::AVG; } +}; + +class GlobalMaxPoolOpBuilder : public GlobalPoolOpBuilder { + public: + GlobalMaxPoolOpBuilder() { pool_type_ = tim::vx::PoolType::MAX; } +}; + +class AveragePoolOpBuilder : public TraditionalPoolOpBuilder { + public: + AveragePoolOpBuilder() { pool_type_ = tim::vx::PoolType::AVG; } +}; + +class MaxPoolOpBuilder : public TraditionalPoolOpBuilder { + public: + MaxPoolOpBuilder() { pool_type_ = tim::vx::PoolType::MAX; } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h new file mode 100644 index 0000000000000..def37b1ec1019 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinear_binary_op_builder.h @@ -0,0 +1,85 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class BaseQLinearOpBuilder : public BaseOpBuilder { + enum { + INPUT_A = 0, + INPUT_A_SCALE = 1, + INPUT_A_ZP = 2, + INPUT_B = 3, + INPUT_B_SCALE = 4, + INPUT_B_ZP = 5, + OUTPUT_SCALE = 6, + OUTPUT_ZP = 7, + }; + + protected: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const override { + for (int i = 0; i < node->InputDefs().size(); i++) { + if (i == INPUT_A || i == INPUT_B) continue; + if (!graph_viewer.IsConstantInitializer(node->InputDefs()[i]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Only support const scale / zero point."; + return false; + } + } + return true; + } +}; + +class QLinearAddOpBuilder : public BaseQLinearOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearAdd Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class QLinearMulOpBuilder : public BaseQLinearOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearMul Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h new file mode 100644 index 0000000000000..dc51e99730c15 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconcat_op_builder.h @@ -0,0 +1,48 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class QLinearConcatOpBuilder : public BaseOpBuilder { + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, std::vector>& inputs, + std::vector>& outputs, const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearConcat Op."; + NodeAttrHelper helper(node_unit.GetNode()); + int axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(axis, inputs.size()); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h new file mode 100644 index 0000000000000..8b63a07e17f1d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearconv_op_builder.h @@ -0,0 +1,151 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include +#include "core/providers/shared/utils/utils.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/framework/tensorprotoutils.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +class QLinearConvOpBuilder : public BaseOpBuilder { + enum QLinearConvINPUTS { + INPUT_TENSOR = 0, + INPUT_TENSOR_SCALE = 1, + INPUT_TENSOR_ZP = 2, + WEIGHT_TENSOR = 3, + WEIGHT_TENSOR_SCALE = 4, + WEIGHT_TENSOR_ZP = 5, + OUTPUT_TENSOR_SCALE = 6, + OUTPUT_TENSOR_ZP = 7, + BIAS_TENSOR = 8, + }; + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[QLinearConvINPUTS::INPUT_TENSOR]); + auto w_scale_shape = vsi::npu::util::GetTensorShape(*input_defs[QLinearConvINPUTS::WEIGHT_TENSOR_SCALE]); + auto w_shape_dims = vsi::npu::util::GetTensorShape(*input_defs[QLinearConvINPUTS::WEIGHT_TENSOR]).GetDims(); + if (input_shape.NumDimensions() != 4) { + LOGS_DEFAULT(WARNING) << "Not support conv3d&& conv1d yet."; + return false; + } + + if (!graph_viewer.IsConstantInitializer(input_defs[QLinearConvINPUTS::INPUT_TENSOR_SCALE]->Name(), true) || + !graph_viewer.IsConstantInitializer(input_defs[WEIGHT_TENSOR]->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Not support quantization definitions or weights that are not constant yet."; + return false; + } + + if (w_shape_dims[2] > 15) { + LOGS_DEFAULT(WARNING) << "Not support weight kernel with height higher than 15."; + return false; + } + + if (w_scale_shape.Size() != 1 && *input_defs[WEIGHT_TENSOR]->Type() == "tensor(int8)") { + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_viewer.GetConstantInitializer(input_defs[QLinearConvINPUTS::WEIGHT_TENSOR_ZP]->Name(), true); + std::vector w_zp(tensor_proto->dims_size() == 0 ? 1 : tensor_proto->dims()[0]); + + auto status = onnxruntime::utils::UnpackTensor( + *tensor_proto, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().data() : nullptr, + tensor_proto->has_raw_data() ? tensor_proto->raw_data().size() : 0, + w_zp.data(), w_zp.size()); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Failed to get data from weight zp tensor."; + return false; + } + if (std::any_of(w_zp.begin(), w_zp.end(), [](int i) { return i != 0; })) { + LOGS_DEFAULT(WARNING) << "Asymmetric perchannel quantization only allows uint8 datatype or int8 with all zero."; + return false; + } + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating QLinearConv Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + auto padtype = helper.Get("auto_pad", std::string("")); + auto group = helper.Get("group", static_cast(1)); + std::vector default_vec = {1, 1, 1, 1}; + auto stride = + helper.Get("strides", default_vec); + auto dilation = + helper.Get("dilations", default_vec); + std::shared_ptr op; + if (padtype != "NOTSET") { // array "pads" is not set + if (group != 1 && group != inputs[1]->GetShape()[3]) { + op = graph_ep->GetGraph() + ->CreateOperation( + vsi::npu::util::GetPadType(padtype), + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + + } else { + int32_t multiplier = group == 1 ? 0 : inputs[1]->GetShape()[3] / inputs[0]->GetShape()[2]; + op = graph_ep->GetGraph()->CreateOperation( + vsi::npu::util::GetPadType(padtype), + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } else { + std::vector default_pads(4, 0); + auto pads = helper.Get("pads", default_pads); + if (group != 1 && group != inputs[1]->GetShape()[3]) { + op = graph_ep->GetGraph() + ->CreateOperation( + std::array{pads[1], pads[3], pads[0], pads[2]}, + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, group, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + + } else { + int32_t multiplier = group == 1 ? 0 : inputs[1]->GetShape()[3] / inputs[0]->GetShape()[2]; + op = graph_ep->GetGraph()->CreateOperation( + std::array{pads[1], pads[3], + pads[0], pads[2]}, + std::array{stride[1], stride[0]}, + std::array{dilation[1], dilation[0]}, multiplier, + tim::vx::DataLayout::WHCN, tim::vx::DataLayout::WHIcOc); + } + } + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h new file mode 100644 index 0000000000000..7447c8b6b0b91 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/qlinearmatmul_op_builder.h @@ -0,0 +1,83 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class QLinearMatMulOpBuilder : public BaseOpBuilder { + enum { + matrixA = 0, + A_scale = 1, + A_zero_point = 2, + matrixB = 3, + B_scale = 4, + B_zero_point = 5, + out_scale = 6, + out_zero_point = 7 + }; + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto A_def = input_defs[matrixA]; + auto B_def = input_defs[matrixB]; + for (auto def : input_defs) { + if (def->Name() == A_def->Name() || def->Name() == B_def->Name()) { + continue; + } else { + if (!graph_viewer.IsConstantInitializer(def->Name(), true)) { + LOGS_DEFAULT(WARNING) << "Scale and zero point must be known before setting graph."; + return false; + } + } + } + int64_t A_elements = util::GetTensorShape(*input_defs[A_scale]).Size(); + int64_t B_elements = util::GetTensorShape(*input_defs[B_scale]).Size(); + int64_t Out_elements = util::GetTensorShape(*input_defs[out_scale]).Size(); + if (A_elements > 1 || B_elements > 1 || Out_elements > 1) { + LOGS_DEFAULT(WARNING) << "Per channel quantized input/output is not supported in QLinearMatmul Op."; + return false; + } + + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating QLinearMatmul Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h new file mode 100644 index 0000000000000..63ae491909bdc --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/quantize_op_builder.h @@ -0,0 +1,79 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class QuantizeLinearOpBuilder : public BaseOpBuilder { + enum QuantizeINPUTS { + input_tensor = 0, + scale_tensor = 1, + zero_point_tensor = 2 + }; + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto scale_shape = npu::util::GetTensorShape(*input_defs[QuantizeINPUTS::scale_tensor]); + NodeAttrHelper helper(*node); + if (helper.HasAttr("block_size") && helper.Get("block_size", 0) != 0) { + LOGS_DEFAULT(WARNING) << "Not support block quantization."; + return false; + } + if (!graph_viewer.IsConstantInitializer(input_defs[QuantizeINPUTS::scale_tensor]->Name(), true) || + (input_defs.size() == 3 && !graph_viewer.IsConstantInitializer( + input_defs[QuantizeINPUTS::zero_point_tensor]->Name(), true))) { + LOGS_DEFAULT(WARNING) << "Only support const scale / zero point."; + return false; + } + + if (scale_shape.Size() != 1) { + LOGS_DEFAULT(WARNING) << "Per channel quantized output is not supported in QuantizeLinearOp."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Quantize Op."; + auto op = graph_ep->GetGraph()->CreateOperation(); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h new file mode 100644 index 0000000000000..3b0a282c5de89 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h @@ -0,0 +1,82 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ReduceMeanOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + if (*input_defs[0]->Type() == "tensor(int32)") { + LOGS_DEFAULT(WARNING) << "Not support int32 reduce mean yet."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating ReduceMean Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + std::vector def_axes; + auto input_shape_size = inputs[0]->GetShape().size(); + + if (node_unit.SinceVersion() < 18 && helper.HasAttr("axes")) { + def_axes = helper.Get("axes", def_axes); + } else if (inputs.size() > 1) { + def_axes.resize(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(def_axes.data()); + } else { + for (int64_t i = 0; i < input_shape_size; ++i) { + def_axes.push_back(i); + } + } + + std::vector axes(def_axes.begin(), def_axes.end()); + axes = util::ReverseAxis(axes, input_shape_size); + + if (helper.HasAttr("noop_with_empty_axes") && inputs.size() == 1 && helper.Get("noop_with_empty_axes", 0) == 1) { + outputs[0] = inputs[0]; + return true; + } + + bool keepdims = helper.Get("keepdims", 1) == 1; + auto op = graph_ep->GetGraph()->CreateOperation(axes, keepdims); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h new file mode 100644 index 0000000000000..8857efe3537ec --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/resize_op_builder.h @@ -0,0 +1,153 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ResizeOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (node_unit.SinceVersion() > 10) { + if (node_unit.Inputs().size() > 2 && !Contains(initializers, node_unit.Inputs()[2].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Scale tensor must be constant."; + return false; + } + if (node_unit.Inputs().size() > 3 && !Contains(initializers, node_unit.Inputs()[3].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Size tensor must be constant."; + return false; + } + } else { + if (!Contains(initializers, node_unit.Inputs()[1].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Scale tensor must be constant."; + return false; + } + } + return true; + } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, const Node* node) const override { + auto shape = vsi::npu::util::GetTensorShape(*node->InputDefs()[0]); + if (shape.NumDimensions() > 4) { + LOGS_DEFAULT(WARNING) << "3D or more dimesions resize is not supported."; + return false; + } + + NodeAttrHelper helper(*node); + if (helper.Get("antialiax", 0) != 0) { + LOGS_DEFAULT(WARNING) << "Antialias attribute is not supported."; + return false; + } + auto& cooridinate = helper.Get("coordinate_transoformation_mode", "half_pixel"); + if (cooridinate != "align_corners" && cooridinate != "half_pixel") { + LOGS_DEFAULT(WARNING) << "Only support half_pixel and align_corners attributes now."; + return false; + } + if (helper.Get("keep_aspect_ratio_policy", "stretch") != "stretch") { + LOGS_DEFAULT(WARNING) << "Not support to keep aspect ratio."; + return false; + } + if (helper.Get("mode", "nearest") == "cubic") { + LOGS_DEFAULT(WARNING) << "Not support the cubic resize type yet."; + return false; + } + if (helper.HasAttr("axes")) { + LOGS_DEFAULT(WARNING) << "Axes-specifying is not support."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Resize Op."; + auto inputs_num = inputs.size(); + bool is_1dresize = inputs[0]->GetShape().size() == 1; + NodeAttrHelper helper(node_unit.GetNode()); + auto onnx_mode = helper.Get("mode", "nearest"); + auto coordinate_transformation = helper.Get("coordinate_transformation_mode", "half_pixel"); + bool is_size_set = helper.HasAttr("size"); + int32_t scale_index = node_unit.SinceVersion() > 10 ? 2 : 1; + + auto resize_type = onnx_mode == "nearest" ? tim::vx::ResizeType::NEAREST_NEIGHBOR : tim::vx::ResizeType::BILINEAR; + bool align_corners = coordinate_transformation == "align_corners"; + bool half_pixel_center = coordinate_transformation == "half_pixel"; + std::shared_ptr op = nullptr; + if (is_1dresize) { + int target_size; + if (is_size_set) { + int64_t onnx_size; + inputs[3]->CopyDataFromTensor(&onnx_size); + target_size = static_cast(onnx_size); + op = graph_ep->GetGraph()->CreateOperation(resize_type, 0.0f, align_corners, + half_pixel_center, target_size); + } else { + float scale; + inputs[scale_index]->CopyDataFromTensor(&scale); + op = graph_ep->GetGraph()->CreateOperation(resize_type, scale, align_corners, + half_pixel_center, 0); + } + } else { + int target_h, target_w; + if (is_size_set) { + std::vector onnx_sizes(inputs[3]->GetShape().size()); + inputs[3]->CopyDataFromTensor(onnx_sizes.data()); + target_h = static_cast(onnx_sizes[1]); + target_w = static_cast(onnx_sizes[0]); + op = graph_ep->GetGraph()->CreateOperation(resize_type, 0.0f, align_corners, + half_pixel_center, target_h, target_w); + } else { + auto input_shape = inputs[0]->GetShape(); + std::vector scales(input_shape.size()); + std::vector out_shape(input_shape.size()); + inputs[scale_index]->CopyDataFromTensor(scales.data()); + for (int i = 0; i < input_shape.size(); i++) { + out_shape[i] = input_shape[i] * scales[input_shape.size() - 1 - i]; + } + op = graph_ep->GetGraph()->CreateOperation(resize_type, 0, align_corners, + half_pixel_center, out_shape[1], out_shape[0]); + } + } + + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h new file mode 100644 index 0000000000000..dad10c1a02518 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/softmax_op_builder.h @@ -0,0 +1,101 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class SoftmaxOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", -1); + auto input_defs = node->InputDefs(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(ERROR) << "Axis is invalid in Softmax."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Softmax Op."; + NodeAttrHelper helper(node_unit.GetNode()); + int32_t def_val = node_unit.SinceVersion() < 13 ? 1 : -1; + auto axis = helper.Get("axis", def_val); + + if (def_val == 1) { + // In earlier opset version of softmax, input is coerced into 2D shape + // Attribute "axis" is to describe the axis of the inputs coerced to 2D but not take part in softmax computation + const bool is_2d_shape = inputs[0]->GetShape().size() == 2 ? true : false; + if (!is_2d_shape) { + axis = HandleNegativeAxis(axis, inputs[0]->GetShape().size()); + auto it = inputs[0]->GetShape().end(); + uint32_t last_dim = std::accumulate(it - axis, it, 1, std::multiplies()); + uint32_t first_dim = std::accumulate(inputs[0]->GetShape().begin(), it - axis, 1, std::multiplies()); + auto reshaped_spec = inputs[0]->GetSpec().AsTransientSpec().SetShape( + std::vector{first_dim, last_dim}); + auto reshaped_input = graph_ep->GetGraph()->CreateTensor(reshaped_spec); + auto reshaped_output = graph_ep->GetGraph()->CreateTensor(inputs[0]->GetSpec().AsTransientSpec()); + + auto reshape_input_op = graph_ep->GetGraph()->CreateOperation( + std::vector{first_dim, last_dim}); + auto softmax_op = graph_ep->GetGraph()->CreateOperation(1, 0); + auto reshaped_output_op = graph_ep->GetGraph()->CreateOperation(inputs[0]->GetShape()); + + (*reshape_input_op).BindInputs(inputs).BindOutput(reshaped_input); + (*softmax_op).BindInput(reshaped_input).BindOutput(reshaped_output); + (*reshaped_output_op).BindInput(reshaped_output).BindOutputs(outputs); + + graph_ep->GetOps().push_back(std::move(reshape_input_op)); + graph_ep->GetOps().push_back(std::move(softmax_op)); + graph_ep->GetOps().push_back(std::move(reshaped_output_op)); + } else { + auto op = graph_ep->GetGraph()->CreateOperation(1, 0); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } + } else { + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + auto op = graph_ep->GetGraph()->CreateOperation(1, static_cast(axis)); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + } + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h new file mode 100644 index 0000000000000..2e1837384618d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/squeeze_op_builder.h @@ -0,0 +1,88 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class SqueezeOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (node_unit.SinceVersion() > 11) { + if (node_unit.Inputs().size() > 1 && !Contains(initializers, node_unit.Inputs()[1].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Only support const axes in Squeeze op."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Squeeze Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + std::vector def_axes; + auto input_shape_size = inputs[0]->GetShape().size(); + + if (node_unit.SinceVersion() < 13 && helper.HasAttr("axes")) { + def_axes = helper.Get("axes", def_axes); + } else if (inputs.size() > 1) { + def_axes.resize(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(def_axes.data()); + } else { // if axes is empty from onnx, check input shape to determine + for (int64_t i = 0; i < input_shape_size; ++i) { + if (inputs[0]->GetShape()[i] == 1) { + def_axes.push_back(i); + } + } + } + + std::vector axes(def_axes.begin(), def_axes.end()); + axes = util::ReverseAxis(axes, input_shape_size); + + std::vector timvx_axes(axes.begin(), axes.end()); + + auto op = graph_ep->GetGraph()->CreateOperation(timvx_axes); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h new file mode 100644 index 0000000000000..427457b521b61 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/tensor_op_builder.h @@ -0,0 +1,142 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class ReshapeOpBuilder : public BaseOpBuilder { + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 5; } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input = node_unit.Inputs()[0]; + auto shape = node_unit.Inputs()[1]; + if (initializers.end() == initializers.find(shape.node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "Target shape of reshape op must be known."; + return false; + } + if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&shape.node_arg)) { + if (*input.node_arg.Type() != "tensor(int64)") { + return true; + } + } + return false; + } + + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + + NodeAttrHelper helper(*node); + const bool allow_zero = helper.Get("allowzero", 0) == 1; + auto& perm_tensor_proto = *graph_viewer.GetConstantInitializer(input_defs[1]->Name(), true); + std::vector perm(perm_tensor_proto.dims()[0]); + auto status = onnxruntime::utils::UnpackTensor( + perm_tensor_proto, + perm_tensor_proto.has_raw_data() ? perm_tensor_proto.raw_data().data() : nullptr, + perm_tensor_proto.has_raw_data() ? perm_tensor_proto.raw_data().size() : 0, + perm.data(), perm.size()); + + // Check if perm has any 0's when allow zero is enabled. + if (allow_zero && std::find(perm.begin(), perm.end(), 0L) != perm.end()) { + LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 as dimension when allowzero is enabled"; + return false; + } + + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Reshape Op."; + std::vector new_shape(inputs[1]->GetShape()[0]); + inputs[1]->CopyDataFromTensor(new_shape.data()); + for (size_t i = 0; i < new_shape.size(); i++) { + if (new_shape[i] == 0) { + new_shape[i] = inputs[0]->GetShape()[inputs[0]->GetShape().size() - i - 1]; + } + } + + int64_t element_count = std::accumulate(new_shape.begin(), new_shape.end(), static_cast(1), + [&](int64_t a, int64_t b) { + return b == -1 ? a : a * b; + }); + auto negative_it = std::find(new_shape.begin(), new_shape.end(), -1); + if (negative_it != new_shape.end()) { + *negative_it = inputs[0]->GetSpec().GetElementNum() / element_count; + } + + std::vector new_shape_uint32(new_shape.begin(), new_shape.end()); + std::reverse(new_shape_uint32.begin(), new_shape_uint32.end()); + auto op = graph_ep->GetGraph()->CreateOperation(new_shape_uint32); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +class TransposeOpBuilder : public BaseOpBuilder { + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + auto input_defs = node->InputDefs(); + auto shape_dim = vsi::npu::util::GetTensorShape(*input_defs[0]).NumDimensions(); + NodeAttrHelper helper(*node); + auto perm = helper.Get("perm", std::vector(shape_dim, 1)); + if (perm.size() != shape_dim) { + LOGS_DEFAULT(VERBOSE) << "Size mismatch between perm vector and input shape."; + return false; + } + return true; + } + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Transpose Op."; + std::vector def_val(inputs[0]->GetShape().size()); + for (int64_t i = 0; i < def_val.size(); i++) def_val[i] = def_val.size() - i - 1; + + NodeAttrHelper helper(node_unit.GetNode()); + def_val = helper.Get("perm", def_val); + std::vector timvx_perm; + for (uint32_t i = 0; i < def_val.size(); i++) { + timvx_perm.push_back(def_val.size() - 1 - def_val[def_val.size() - i - 1]); + } + auto op = graph_ep->GetGraph()->CreateOperation(timvx_perm); + (*op).BindInputs(inputs).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h new file mode 100644 index 0000000000000..d42624c31557c --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/tile_op_builder.h @@ -0,0 +1,71 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class TileOpBuilder : public BaseOpBuilder { + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 6; } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input = node_unit.Inputs()[0]; + auto multipliers = node_unit.Inputs()[1]; + if (initializers.end() == initializers.find(multipliers.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Multipliers of tile op must be known."; + return false; + } + if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&multipliers.node_arg)) { + if (*input.node_arg.Type() != "tensor(int64)") { + return true; + } + } + LOGS_DEFAULT(WARNING) << "Input type not supported."; + return false; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Tile Op."; + std::vector multipliers(inputs[1]->GetShape()[0]); + inputs[1]->CopyDataFromTensor(multipliers.data()); + std::reverse(multipliers.begin(), multipliers.end()); + std::vector timvx_multipliers(multipliers.begin(), multipliers.end()); + auto op = graph_ep->GetGraph()->CreateOperation(timvx_multipliers); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h new file mode 100644 index 0000000000000..c49c93008b25a --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/unsqueeze_op_builder.h @@ -0,0 +1,89 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class UnsqueezeOpBuilder : public BaseOpBuilder { + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + auto input_type = node_unit.Inputs()[0].node_arg.Type(); + if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) { + LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : " + << *input_type; + return false; + } + if (node_unit.SinceVersion() > 11 && !Contains(initializers, node_unit.Inputs()[1].node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Only support const axes in Unsqueeze op."; + return false; + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(INFO) << "Creating Unsqueeze Op."; + + NodeAttrHelper helper(node_unit.GetNode()); + std::vector def_axes; + auto input_shape_size = inputs[0]->GetShape().size(); + + if (node_unit.SinceVersion() < 13 && helper.HasAttr("axes")) { + def_axes = helper.Get("axes", def_axes); + } else if (inputs.size() > 1) { + def_axes.resize(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(def_axes.data()); + } else { // if axes is empty from onnx, check input shape to determine + for (int64_t i = 0; i < input_shape_size; ++i) { + if (inputs[0]->GetShape()[i] == 1) { + def_axes.push_back(i); + } + } + } + + std::vector axes(def_axes.begin(), def_axes.end()); + axes = util::ReverseAxis(axes, input_shape_size + axes.size()); + + std::vector timvx_axes(inputs[0]->GetShape().begin(), inputs[0]->GetShape().end()); + for (int32_t dim : axes) { + timvx_axes.insert(timvx_axes.begin() + dim, 1); + } + + auto op = graph_ep->GetGraph()->CreateOperation(timvx_axes); + (*op).BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder.h b/onnxruntime/core/providers/vsinpu/builders/op_builder.h new file mode 100644 index 0000000000000..d81a478149c6b --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder.h @@ -0,0 +1,48 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +class GraphEP; + +class IOpBuilder { + public: + IOpBuilder() {} + virtual ~IOpBuilder() {} + virtual bool IsSupported(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) const { + return true; + } + virtual bool BuildOp(GraphEP* graph_ep, + const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) = 0; +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h new file mode 100644 index 0000000000000..3a9190d8cb03a --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h @@ -0,0 +1,133 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include "impl/activation_op_builder.h" +#include "impl/conv_op_builder.h" +#include "impl/elementwise_op_builder.h" +#include "impl/gemm_op_builder.h" +#include "impl/pool_op_builder.h" +#include "impl/qlinearconv_op_builder.h" +#include "impl/flatten_op_builder.h" +#include "impl/matmul_op_builder.h" +#include "impl/tensor_op_builder.h" +#include "impl/concat_op_builder.h" +#include "impl/softmax_op_builder.h" +#include "impl/norm_op_builder.h" +#include "impl/clip_op_builder.h" +#include "impl/reduce_op_builder.h" +#include "impl/quantize_op_builder.h" +#include "impl/dequantize_op_builder.h" +#include "impl/qlinearmatmul_op_builder.h" +#include "impl/qlinear_binary_op_builder.h" +#include "impl/qlinearconcat_op_builder.h" +#include "impl/gather_op_builder.h" +#include "impl/tile_op_builder.h" +#include "impl/squeeze_op_builder.h" +#include "impl/unsqueeze_op_builder.h" +#include "impl/resize_op_builder.h" +#include "impl/cast_op_builder.h" +namespace onnxruntime { +namespace vsi { +namespace npu { +using createIOpBuildItemFunc = std::function()>; +using OpBuildItemType = std::map>; + +static const std::map reg = { +#define REGISTER_OP_BUILDER(ONNX_NODE_TYPE, BUILDER_TYPE) \ + { \ + ONNX_NODE_TYPE, [] { return std::make_unique(); } \ + } + + REGISTER_OP_BUILDER("Add", AddOpBuilder), + REGISTER_OP_BUILDER("Sub", SubOpBuilder), + REGISTER_OP_BUILDER("Mul", MulOpBuilder), + REGISTER_OP_BUILDER("Div", DivOpBuilder), + REGISTER_OP_BUILDER("Abs", AbsOpBuilder), + REGISTER_OP_BUILDER("Pow", PowOpBuilder), + REGISTER_OP_BUILDER("Sqrt", SqrtOpBuilder), + REGISTER_OP_BUILDER("Exp", ExpOpBuilder), + REGISTER_OP_BUILDER("Floor", FloorOpBuilder), + REGISTER_OP_BUILDER("Log", LogOpBuilder), + REGISTER_OP_BUILDER("Sin", SinOpBuilder), + REGISTER_OP_BUILDER("Conv", ConvOpBuilder), + REGISTER_OP_BUILDER("Gemm", GemmOpBuilder), + REGISTER_OP_BUILDER("Relu", ReluOpBuilder), + REGISTER_OP_BUILDER("LeakyRelu", LeakyReluOpBuilder), + REGISTER_OP_BUILDER("Tanh", TanhOpBuilder), + REGISTER_OP_BUILDER("Sigmoid", SigmoidOpBuilder), + REGISTER_OP_BUILDER("HardSigmoid", HardSigmoidOpBuilder), + REGISTER_OP_BUILDER("HardSwish", HardSwishOpBuilder), + REGISTER_OP_BUILDER("GlobalAveragePool", GlobalAveragePoolOpBuilder), + REGISTER_OP_BUILDER("QLinearConv", QLinearConvOpBuilder), + REGISTER_OP_BUILDER("Flatten", FlattenOpBuilder), + REGISTER_OP_BUILDER("MatMul", MatMulOpBuilder), + REGISTER_OP_BUILDER("GlobalMaxPool", GlobalMaxPoolOpBuilder), + REGISTER_OP_BUILDER("AveragePool", AveragePoolOpBuilder), + REGISTER_OP_BUILDER("MaxPool", MaxPoolOpBuilder), + REGISTER_OP_BUILDER("Reshape", ReshapeOpBuilder), + REGISTER_OP_BUILDER("Concat", ConcatOpBuilder), + REGISTER_OP_BUILDER("Softmax", SoftmaxOpBuilder), + REGISTER_OP_BUILDER("Transpose", TransposeOpBuilder), + REGISTER_OP_BUILDER("BatchNormalization", BatchNormOpBuilder), + REGISTER_OP_BUILDER("Clip", ClipOpBuilder), + REGISTER_OP_BUILDER("ReduceMean", ReduceMeanOpBuilder), + REGISTER_OP_BUILDER("QuantizeLinear", QuantizeLinearOpBuilder), + REGISTER_OP_BUILDER("DequantizeLinear", DequantizeLinearOpBuilder), + REGISTER_OP_BUILDER("QLinearMatMul", QLinearMatMulOpBuilder), + REGISTER_OP_BUILDER("QLinearAdd", QLinearAddOpBuilder), + REGISTER_OP_BUILDER("QLinearMul", QLinearMulOpBuilder), + REGISTER_OP_BUILDER("QLinearConcat", QLinearConcatOpBuilder), + REGISTER_OP_BUILDER("Gather", GatherOpBuilder), + REGISTER_OP_BUILDER("Tile", TileOpBuilder), + REGISTER_OP_BUILDER("Squeeze", SqueezeOpBuilder), + REGISTER_OP_BUILDER("Unsqueeze", UnsqueezeOpBuilder), + REGISTER_OP_BUILDER("Resize", ResizeOpBuilder), + REGISTER_OP_BUILDER("Cast", CastOpBuilder), + +#undef REGISTER_OP_BUILDER +}; + +template +struct OpBuildConstructor { + T supported_builtins; + OpBuildConstructor( + const std::map reg) { + LOGS_DEFAULT(INFO) << "Initialize supported ops"; + for (const auto& kv : reg) { + supported_builtins.insert(std::make_pair(kv.first, kv.second())); + } + } +}; + +inline const OpBuildItemType& SupportedBuiltinOps() { + static OpBuildConstructor c(reg); + return c.supported_builtins; +} +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch b/onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch new file mode 100644 index 0000000000000..d44190101d9fa --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch @@ -0,0 +1,26 @@ +diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc +index 47c18c478d..93b44501cd 100644 +--- a/onnxruntime/test/providers/checkers.cc ++++ b/onnxruntime/test/providers/checkers.cc +@@ -195,7 +195,7 @@ struct TensorCheck { + // For any other EPs, we still expect an exact match for the results + // TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513 + if ((provider_type == kNnapiExecutionProvider || provider_type == kDmlExecutionProvider || +- provider_type == kXnnpackExecutionProvider) && ++ provider_type == kXnnpackExecutionProvider || provider_type == kVSINPUExecutionProvider) && + (has_abs_err || has_rel_err)) { + double threshold = has_abs_err ? *(params.absolute_error) + : 0.0; +diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +index 2bc0df5e36..7beb78c2ff 100644 +--- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc ++++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +@@ -498,7 +498,7 @@ class QLinearConvOpTester { + // NOTE, for now the tolerance will only apply if the NNAPI is actually used, + // if for any reason the execution falls back to CPU, we still expect an exact match + // See, 'void Check(...' in onnxruntime/test/providers/provider_test_utils.cc +-#if defined(USE_NNAPI) || defined(USE_DML) ++#if defined(USE_NNAPI) || defined(USE_DML) || defined(USE_VSINPU) + // TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513 + abs_error = 1.0f; + #endif diff --git a/onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch b/onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch new file mode 100644 index 0000000000000..e118ee104912f --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/local_testing_record_res.patch @@ -0,0 +1,343 @@ +diff --git a/onnxruntime/test/onnx/dataitem_request.cc b/onnxruntime/test/onnx/dataitem_request.cc +index 1ee302d5d5..5c2dd5ab00 100644 +--- a/onnxruntime/test/onnx/dataitem_request.cc ++++ b/onnxruntime/test/onnx/dataitem_request.cc +@@ -135,6 +135,7 @@ std::pair DataTaskRequestContext::RunImpl() { + } + + EXECUTE_RESULT res = EXECUTE_RESULT::SUCCESS; ++ int32_t out_idx = 0; + for (auto& output : expected_output_values) { + const std::string& output_name = output.first; + OrtValue* expected_output_value = output.second; // Automatic cast +@@ -170,7 +171,7 @@ std::pair DataTaskRequestContext::RunImpl() { + } else { // Both expect and actual OrtValues are not None, proceed with data checking + ret = + CompareOrtValue(*actual_output_value, *expected_output_value, per_sample_tolerance, +- relative_per_sample_tolerance, post_procesing); ++ relative_per_sample_tolerance, post_procesing, out_idx); + } + } else { // Expected output is None, ensure that the received output OrtValue is None as well + if (actual_output_value->IsAllocated()) { +@@ -223,9 +224,10 @@ std::pair DataTaskRequestContext::RunImpl() { + if (compare_result != COMPARE_RESULT::SUCCESS && !ret.second.empty()) { + LOGS_DEFAULT(ERROR) << test_case_.GetTestCaseName() << ":output=" << output_name << ":" << ret.second; + } +- if (compare_result != COMPARE_RESULT::SUCCESS) { +- break; +- } ++ // if (compare_result != COMPARE_RESULT::SUCCESS) { ++ // break; ++ // } ++ out_idx ++; + } + return std::make_pair(res, spent_time_); + } +diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc +index f1a7240ea3..436031dfa8 100644 +--- a/onnxruntime/test/providers/checkers.cc ++++ b/onnxruntime/test/providers/checkers.cc +@@ -154,6 +154,7 @@ struct TensorCheck { + } + + const bool has_abs_err = params.absolute_error.has_value(); ++ const int8_t default_abs_err = 1; + if (has_abs_err) { + double threshold = *(params.absolute_error); + +@@ -162,7 +163,8 @@ struct TensorCheck { + } + } else { + for (int i = 0; i < size; ++i) { +- EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; ++ // EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; ++ EXPECT_NEAR(cur_expected[i], cur_actual[i], default_abs_err) << "i:" << i; + } + } + } +diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc +index 3d53d4a3a0..8129af1820 100644 +--- a/onnxruntime/test/util/compare_ortvalue.cc ++++ b/onnxruntime/test/util/compare_ortvalue.cc +@@ -138,11 +138,75 @@ std::pair CompareFloatResult(const Tensor& outvalue + return res; + } + ++template ++std::pair CompareFloatResult(const Tensor& outvalue, const Tensor& expected_value, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx) { ++ const size_t size1 = static_cast(expected_value.Shape().Size()); ++ const FLOAT_TYPE* expected_output = expected_value.Data(); ++ const FLOAT_TYPE* real_output = outvalue.Data(); ++ ++ std::string expected_name = "expected_res"+ std::to_string(out_idx)+ ".txt"; ++ std::string npures_name = "npu_res"+ std::to_string(out_idx)+ ".txt"; ++ std::ofstream expected_res(expected_name), npu_res(npures_name); ++ for(size_t i = 0 ; i < size1; i++){ ++ expected_res << expected_output[i] << std::endl; ++ npu_res << real_output[i] << std::endl; ++ } ++ expected_res.close(); ++ npu_res.close(); ++ ++ std::pair res = std::make_pair(COMPARE_RESULT::SUCCESS, ""); ++ double max_diff = 0; ++ size_t diff_count = 0; ++ for (size_t di = 0; di != size1; ++di) { ++ const double real_value = ++ post_processing ? std::max(0.0, std::min(255.0, real_output[di])) : real_output[di]; ++ const double diff = std::fabs(expected_output[di] - real_value); ++ const double tol = per_sample_tolerance + relative_per_sample_tolerance * std::fabs(expected_output[di]); ++ if (!IsResultCloselyMatch(real_value, expected_output[di], diff, tol)) { ++ res.first = COMPARE_RESULT::RESULT_DIFFERS; ++ // update error message if this is a larger diff ++ if (diff > max_diff || (std::isnan(diff) && !std::isnan(max_diff))) { ++ int64_t expected_int = 0; ++ int64_t real_int = 0; ++ memcpy(&expected_int, &expected_output[di], sizeof(FLOAT_TYPE)); ++ memcpy(&real_int, &real_output[di], sizeof(FLOAT_TYPE)); ++ ++ std::ostringstream oss; ++ oss << std::hex << "expected " << expected_output[di] << " (" << expected_int << "), got " << real_value << " (" ++ << real_int << ")" ++ << ", diff: " << diff << ", tol=" << tol << std::dec << " idx=" << di << "."; ++ res.second = oss.str(); ++ max_diff = diff; ++ } ++ ++diff_count; ++ } ++ } ++ ++ if (res.first == COMPARE_RESULT::SUCCESS) return res; ++ ++ std::ostringstream oss; ++ oss << res.second << " " << diff_count << " of " << size1 << " differ"; ++ res.second = oss.str(); ++ return res; ++} ++ ++ + template +-std::pair IsResultExactlyMatch(const Tensor& outvalue, const Tensor& expected_value) { ++std::pair IsResultExactlyMatch(const Tensor& outvalue, const Tensor& expected_value, int32_t out_idx) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); ++ std::string expected_name = "expected_res"+ std::to_string(out_idx)+ ".txt"; ++ std::string npures_name = "npu_res"+ std::to_string(out_idx)+ ".txt"; ++ std::ofstream expected_res(expected_name), npu_res(npures_name); ++ for(size_t i = 0 ; i < size1; i++){ ++ expected_res << expected_output[i] << std::endl; ++ npu_res << real_output[i] << std::endl; ++ } ++ expected_res.close(); ++ npu_res.close(); + for (size_t di = 0; di != size1; ++di) { + if (expected_output[di] != real_output[di]) { + std::ostringstream oss; +@@ -201,7 +265,7 @@ std::pair CompareBFloat16Result(const Tensor& outva + + std::pair CompareTwoTensors(const Tensor& outvalue, const Tensor& expected_tensor, + double per_sample_tolerance, +- double relative_per_sample_tolerance, bool post_processing) { ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx) { + if (expected_tensor.Shape() != outvalue.Shape()) { + std::ostringstream oss; + oss << "shape mismatch, expect " << expected_tensor.Shape().ToString() << " got " << outvalue.Shape().ToString(); +@@ -209,30 +273,30 @@ std::pair CompareTwoTensors(const Tensor& outvalue, + } + if (outvalue.IsDataType()) { + return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, out_idx); + } else if (outvalue.IsDataType()) { + return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, out_idx); + } else if (outvalue.IsDataTypeString()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { +- return IsResultExactlyMatch(outvalue, expected_tensor); ++ return IsResultExactlyMatch(outvalue, expected_tensor, out_idx); + } else if (outvalue.IsDataType()) { + return CompareFloat16Result(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, + post_processing); +@@ -300,7 +364,7 @@ std::pair CompareSparseTensors(const SparseTensor& + " actual: ", actual.Format()); + + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual.Values(), expected.Values(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "While comparing sparse values"); + + if (actual.Format() == SparseFormat::kCoo) { +@@ -308,16 +372,16 @@ std::pair CompareSparseTensors(const SparseTensor& + auto expected_view = expected.AsCoo(); + + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual_view.Indices(), expected_view.Indices(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "Comparing COO indices"); + } else if (actual.Format() == SparseFormat::kCsrc) { + auto actual_view = actual.AsCsr(); + auto expected_view = expected.AsCsr(); + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual_view.Inner(), expected_view.Inner(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "Comparing Csr(c) inner indices"); + TEST_RETURN_IF_ERROR(CompareTwoTensors(actual_view.Outer(), expected_view.Outer(), +- per_sample_tolerance, relative_per_sample_tolerance, post_processing), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing, 0), + "Comparing Csr(c) outer indices"); + } + +@@ -385,7 +449,83 @@ std::pair CompareOrtValue(const OrtValue& o, const + return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, oss.str()); + } + return CompareTwoTensors(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, 0); ++ } else if (o.IsSparseTensor()) { ++#if !defined(DISABLE_SPARSE_TENSORS) ++ TEST_RETURN_IF_NOT(expected_mlvalue.IsSparseTensor(), COMPARE_RESULT::TYPE_MISMATCH, ++ "SparseTensor is not expected as output"); ++ TEST_RETURN_IF_ERROR(CompareSparseTensors(o.Get(), expected_mlvalue.Get(), ++ per_sample_tolerance, relative_per_sample_tolerance, ++ post_processing), ++ "while comaring sparse tensors"); ++#endif ++ return std::make_pair(COMPARE_RESULT::SUCCESS, ""); ++ } else if (o.IsTensorSequence()) { ++ auto& expected_tensor_seq = expected_mlvalue.Get(); ++ auto expected_tensor_count = expected_tensor_seq.Size(); ++ ++ auto& actual_tensor_seq = o.Get(); ++ auto actual_tensor_count = actual_tensor_seq.Size(); ++ ++ if (expected_tensor_count != actual_tensor_count) { ++ std::ostringstream oss; ++ oss << "expected tensor count in the sequence: " << expected_tensor_count << " got " ++ << actual_tensor_count; ++ return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); ++ } ++ ++ if (!expected_tensor_seq.IsSameDataType(actual_tensor_seq)) { ++ std::ostringstream oss; ++ oss << "expected tensor type in the sequence: " << expected_tensor_seq.DataType() << " got " ++ << actual_tensor_seq.DataType(); ++ return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, oss.str()); ++ } ++ ++ for (size_t i = 0; i < expected_tensor_count; ++i) { ++ auto res = CompareTwoTensors(actual_tensor_seq.Get(i), expected_tensor_seq.Get(i), per_sample_tolerance, relative_per_sample_tolerance, ++ post_processing,0); ++ if (res.first != COMPARE_RESULT::SUCCESS) { ++ return res; ++ } ++ } ++ ++ return std::make_pair(COMPARE_RESULT::SUCCESS, ""); ++ ++ } else { ++ // Maps ++#if !defined(DISABLE_ML_OPS) ++ if (o.Type() == DataTypeImpl::GetType()) { ++ return CompareSeqOfMapToFloat(o.Get(), expected_mlvalue.Get(), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing); ++ } ++ if (o.Type() == DataTypeImpl::GetType()) { ++ return CompareSeqOfMapToFloat(o.Get(), expected_mlvalue.Get(), ++ per_sample_tolerance, relative_per_sample_tolerance, post_processing); ++ } ++ return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); ++#else ++ return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, "Map type is not supported in this build."); ++#endif ++ } ++} ++ ++std::pair CompareOrtValue(const OrtValue& o, const OrtValue& expected_mlvalue, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx) { ++ if (o.Type() != expected_mlvalue.Type()) { ++ return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, ""); ++ } ++ if (o.IsTensor()) { ++ const Tensor& outvalue = o.Get(); ++ const Tensor& expected_tensor = expected_mlvalue.Get(); ++ if (outvalue.DataType() != expected_tensor.DataType()) { ++ std::ostringstream oss; ++ oss << "expect " << ElementTypeToString(expected_tensor.DataType()) << " got " ++ << ElementTypeToString(outvalue.DataType()); ++ return std::make_pair(COMPARE_RESULT::TYPE_MISMATCH, oss.str()); ++ } ++ return CompareTwoTensors(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, ++ post_processing, out_idx); + } else if (o.IsSparseTensor()) { + #if !defined(DISABLE_SPARSE_TENSORS) + TEST_RETURN_IF_NOT(expected_mlvalue.IsSparseTensor(), COMPARE_RESULT::TYPE_MISMATCH, +@@ -419,7 +559,7 @@ std::pair CompareOrtValue(const OrtValue& o, const + + for (size_t i = 0; i < expected_tensor_count; ++i) { + auto res = CompareTwoTensors(actual_tensor_seq.Get(i), expected_tensor_seq.Get(i), per_sample_tolerance, relative_per_sample_tolerance, +- post_processing); ++ post_processing, out_idx); + if (res.first != COMPARE_RESULT::SUCCESS) { + return res; + } +diff --git a/onnxruntime/test/util/include/compare_ortvalue.h b/onnxruntime/test/util/include/compare_ortvalue.h +index 24b74b9002..8269346528 100644 +--- a/onnxruntime/test/util/include/compare_ortvalue.h ++++ b/onnxruntime/test/util/include/compare_ortvalue.h +@@ -24,7 +24,9 @@ enum class COMPARE_RESULT { SUCCESS, + std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, + double per_sample_tolerance, + double relative_per_sample_tolerance, bool post_processing); +- ++std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx); + // verify if the 'value' matches the 'expected' ValueInfoProto. 'value' is a model output + std::pair VerifyValueInfo(const ONNX_NAMESPACE::ValueInfoProto& expected, + const OrtValue* value); +diff --git a/onnxruntime/test/util/include/test/compare_ortvalue.h b/onnxruntime/test/util/include/test/compare_ortvalue.h +index 545df706c9..170eb9dc4c 100644 +--- a/onnxruntime/test/util/include/test/compare_ortvalue.h ++++ b/onnxruntime/test/util/include/test/compare_ortvalue.h +@@ -28,7 +28,9 @@ enum class COMPARE_RESULT { + std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, + double per_sample_tolerance, + double relative_per_sample_tolerance, bool post_processing); +- ++std::pair CompareOrtValue(const OrtValue& real, const OrtValue& expected, ++ double per_sample_tolerance, ++ double relative_per_sample_tolerance, bool post_processing, int32_t out_idx); + // Compare two OrtValue numerically equal or not. The difference with CompareOrtValue is that this function + // will only check the numerical values of the OrtValue, and ignore the type, shape, etc. + // diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch new file mode 100644 index 0000000000000..a9d02765cf34d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -0,0 +1,34 @@ +diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake +index e0ccc504d7..6c5aa6ea53 100644 +--- a/cmake/onnxruntime_mlas.cmake ++++ b/cmake/onnxruntime_mlas.cmake +@@ -335,7 +335,7 @@ else() + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ) +- if (NOT APPLE) ++ if (NOT APPLE AND NOT onnxruntime_USE_VSINPU) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S +diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h +index fd6b3df934..f81f1c42b6 100644 +--- a/onnxruntime/core/mlas/inc/mlas.h ++++ b/onnxruntime/core/mlas/inc/mlas.h +@@ -79,6 +79,7 @@ Abstract: + + #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) + #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) ++#if !defined(USE_VSINPU) + #if !defined(__APPLE__) + // Had to temporary disable fp16 under APPLE ARM64, as compiling + // the source files require a hardware specific compilation flag. +@@ -87,7 +88,8 @@ Abstract: + + #define MLAS_F16VEC_INTRINSICS_SUPPORTED + +-#endif // ++#endif ++#endif // + #endif // ARM64 + #endif // Visual Studio 16 or earlier does not support fp16 intrinsic diff --git a/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py new file mode 100644 index 0000000000000..e4e9b44fdc252 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_cosine_sim.py @@ -0,0 +1,29 @@ +import sys + +import numpy as np +from numpy.linalg import norm + + +def read_values(filename): + with open(filename) as file: + values = np.array([float(line.strip()) for line in file]) + return values + + +def cosine_similarity(vec1, vec2): + return np.dot(vec1, vec2) / (norm(vec1) * norm(vec2)) + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: python cosine_similarity.py ") + sys.exit(1) + + file1 = sys.argv[1] + file2 = sys.argv[2] + + vec1 = read_values(file1) + vec2 = read_values(file2) + + similarity = cosine_similarity(vec1, vec2) + print(f"Cosine Similarity: {similarity}") diff --git a/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py new file mode 100644 index 0000000000000..cde75b7f18c1e --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/test_scripts/compare_topn.py @@ -0,0 +1,34 @@ +import sys + + +def read_values(filename): + with open(filename) as file: + values = [(float(line.strip()), i + 1) for i, line in enumerate(file)] + return values + + +def top_n(values, N): + return sorted(values, key=lambda x: x[0], reverse=True)[:N] + + +def compare_files(cpu_file, npu_file, N): + cpu_values = read_values(cpu_file) + npu_values = read_values(npu_file) + + cpu_topn = top_n(cpu_values, N) + npu_topn = top_n(npu_values, N) + + print(f"Top-{N} values in {cpu_file}: {cpu_topn}") + print(f"Top-{N} values in {npu_file}: {npu_topn}") + + +if __name__ == "__main__": + if len(sys.argv) != 4: + print("Usage: python compare_topn.py ") + sys.exit(1) + + N = int(sys.argv[1]) + cpu_file = sys.argv[2] + npu_file = sys.argv[3] + + compare_files(cpu_file, npu_file, N) diff --git a/onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh b/onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh new file mode 100644 index 0000000000000..c27af51c26799 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/patches/test_scripts/result_compare.sh @@ -0,0 +1,23 @@ +#!/bin/bash +res_file_dir=$1 +output_num=$2 + +# specifying N value +N=5 + +for i in $(seq 0 $((output_num-1))); +do + # 构建文件名 + golden_file="${res_file_dir}/expected_res${i}.txt" + npu_file="${res_file_dir}/npu_res${i}.txt" + + echo "Comparing Top-${N} for the output_${i}" + python3 compare_topn.py $N $golden_file $npu_file + + echo "--------------------------------" + + echo "Comparing Cosine Similarity for output_${i}:" + python3 compare_cosine_sim.py $golden_file $npu_file + + echo "" +done diff --git a/onnxruntime/core/providers/vsinpu/symbols.txt b/onnxruntime/core/providers/vsinpu/symbols.txt new file mode 100644 index 0000000000000..d69c92692f5fe --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/symbols.txt @@ -0,0 +1 @@ +OrtSessionOptionsAppendExecutionProvider_VSINPU diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc new file mode 100644 index 0000000000000..e51b0713ea41d --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -0,0 +1,296 @@ + +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include "core/providers/vsinpu/vsinpu_ep_graph.h" +#include "core/providers/vsinpu/builders/op_builder_factory.h" +#include "core/providers/vsinpu/vsinpu_util.h" +#include "core/framework/node_unit.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" + +namespace onnxruntime { + +namespace vsi { +namespace npu { +GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { + Prepare(); + context_ = tim::vx::Context::Create(); + graph_ = context_->CreateGraph(); + compiled_ = false; +} + +bool GraphEP::Prepare() { + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + for (const auto& node_unit : node_unit_holder_) { + auto quant_op_type = util::GetQuantizedOpType(*node_unit); + + // Not a qlinear op or qdq node group + if (quant_op_type == util::QuantizedOpType::Unknown) + continue; + + const auto add_quantized_input = + [&all_quantized_op_inputs = all_quantized_op_inputs_](const NodeUnit& node_unit, size_t input_idx) { + const auto& input_name = node_unit.Inputs()[input_idx].node_arg.Name(); + all_quantized_op_inputs[input_name].push_back(&node_unit); + }; + + // All quantized ops EXCEPT QuantizeLinear has quantized input + if (quant_op_type != util::QuantizedOpType::QuantizeLinear) { + add_quantized_input(*node_unit, 0); + } + + if (util::IsQuantizedBinaryOp(quant_op_type)) { + add_quantized_input(*node_unit, 1); + if (util::IsQuantizedConv(quant_op_type) && node_unit->Inputs().size() == 3) { + add_quantized_input(*node_unit, 2); + } + } + } // All quantized inputs is recorded + return true; +} + +bool GraphEP::SupportedOp(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit) { + const auto& supported_builtins = vsi::npu::SupportedBuiltinOps(); + const auto& target_node = node_unit.GetNode(); + const auto& it = supported_builtins.find(target_node.OpType()); + if (supported_builtins.end() != it) { + return it->second->IsSupported(graph_viewer, node_unit); + } + LOGS_DEFAULT(WARNING) << "Fallback unsupported op (node_unit) " << node_unit.OpType() + << " to cpu."; + return false; +} + +bool GraphEP::IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer) { + return SupportedOp(graph_viewer, node_unit); +} + +const NodeUnit& GraphEP::GetNodeUnit(const Node* node) const { + const auto node_unit_it = node_unit_map_.find(node); + ORT_ENFORCE(node_unit_it != node_unit_map_.end(), "Node does not have corresponding NodeUnit."); + return *node_unit_it->second; +} + +void GraphEP::UpdateTensorMap(const std::string& name, const std::shared_ptr& dst_tensor) { + auto it = tensors_.find(name); + if (it != tensors_.end()) { + it->second = dst_tensor; + } + for (auto& IO : graph_inputs_) { + if (IO->name == name) { + IO->tensor = dst_tensor; + break; + } + } + for (auto& IO : graph_outputs_) { + if (IO->name == name) { + IO->tensor = dst_tensor; + break; + } + } +} + +std::shared_ptr GraphEP::ConstructNodeIO(const std::shared_ptr& op, std::vector input_arg, std::vector output_arg) { + auto info = std::make_shared(); + info->op_ = op; + std::vector input_names, output_names; + if (input_arg.empty()) { + info->input_names_ = std::vector(); + } else { + input_names.reserve(input_arg.size()); + std::transform(input_arg.begin(), input_arg.end(), std::back_inserter(input_names), + [](const NodeArg* node) -> std::string { + return node->Name(); + }); + info->input_names_ = input_names; + } + if (output_arg.empty()) { + info->output_names_ = std::vector(); + } else { + output_names.reserve(output_arg.size()); + std::transform(output_arg.begin(), output_arg.end(), std::back_inserter(output_names), + [](const NodeArg* node) -> std::string { + return node->Name(); + }); + info->output_names_ = output_names; + } + + return info; +} + +bool GraphEP::BindTensors(const std::shared_ptr& nodeio_info) { + auto op = nodeio_info->op_; + auto input_names = nodeio_info->input_names_; + auto output_names = nodeio_info->output_names_; + if (!input_names.empty()) { + for (auto& name : input_names) { + if (tensors_.find(name) == tensors_.end() || tensors_[name] == nullptr) { + LOGS_DEFAULT(ERROR) << "Input tensor not defined or not found!"; + return false; + } + (*op).BindInput(tensors_[name]); + } + } + if (!output_names.empty()) { + for (auto& name : output_names) { + if (tensors_.find(name) == tensors_.end() || tensors_[name] == nullptr) { + LOGS_DEFAULT(ERROR) << "Output tensor not defined or not found!"; + return false; + } + (*op).BindOutput(tensors_[name]); + } + } + return true; +} + +std::shared_ptr GraphEP::MapTIMVXTensor( + std::shared_ptr& graph, const NodeUnitIODef nudef, + const NodeUnit& node_unit, + const GraphViewer* graph_viewer, tim::vx::TensorAttribute attribute) { + const auto& arg = nudef.node_arg; + + if (tensors_.end() != tensors_.find(nudef.node_arg.Name())) { + // if (!quant_param.has_value() || quant_param.has_value() && tensors_[arg.Name()]->GetSpec().GetQuantization().Type() != tim::vx::QuantType::NONE) + return tensors_.find(arg.Name())->second; + } + auto shape = vsi::npu::util::OnnxShapeToTIMVXShape(vsi::npu::util::GetTensorShape(arg)); + std::reverse(shape.begin(), shape.end()); + tim::vx::DataType dt = vsi::npu::util::OnnxDtypeToTIMVXDtype(arg.Type()); + tim::vx::TensorSpec spec = tim::vx::TensorSpec(dt, shape, attribute); + + // Tensors have same name may not have same status of quant_param existence, such as QLinearConv->MaxPool->QLinearConv + // Maxpool output tensor is not set quantization at first pass + bool is_qtensor = nudef.quant_param.has_value() || Contains(all_quantized_op_inputs_, arg.Name()); + if (is_qtensor) { + float scale = 0.0f; + int32_t zp = 0; + std::optional> scales; + std::optional> zps; + if (nudef.quant_param.has_value()) { + util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + nudef, node_unit.ModelPath(), + scale, zp, scales, zps); + } else { + auto target_nodeunit = all_quantized_op_inputs_[arg.Name()][0]; + auto qinput = all_quantized_op_inputs_[arg.Name()][0]->Inputs(); + auto it = std::find_if(qinput.begin(), qinput.end(), [&arg](const NodeUnitIODef& nud) { return nud.node_arg.Name() == arg.Name(); }); + bool is_conv_bias = std::distance(qinput.begin(), it) == 2; + if (!is_conv_bias || it->quant_param.has_value()) { + util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + *it, target_nodeunit->ModelPath(), + scale, zp, scales, zps); + } else if (!it->quant_param.has_value()) { + float in_scale, w_scale; + int32_t in_zp, w_zp; + std::optional> in_scales, w_scales; + std::optional> in_zps, w_zps; + + // onnx defines conv bias with non quantization, but it must be quantized in VSINPU support + // The bias scale is set as input_scale * weight_scale if per layer quantized, input_scale* weight_scale[i] if per channel quantized + util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + qinput[0], target_nodeunit->ModelPath(), + in_scale, in_zp, in_scales, in_zps); + util::GetQuantizationScaleAndZeroPoint(graph_viewer_.GetAllInitializedTensors(), + qinput[1], target_nodeunit->ModelPath(), + w_scale, w_zp, w_scales, w_zps); + scale = in_scale * w_scale; + zp = 0; + if (w_scales) { + std::vector temp; + for (size_t i = 0; i < w_scales->size(); i++) { + temp.push_back(w_scales.value()[i] * in_scale); + } + scales = temp; + } + } + } + tim::vx::Quantization quant; + // per tensor quantization + if (!scales.has_value()) { + quant.SetType(tim::vx::QuantType::ASYMMETRIC); + quant.SetScales({scale}); + quant.SetZeroPoints({zp}); + } else { // per channel quantization + if (zps.has_value()) { + bool has_nonzero = std::find_if(zps->begin(), zps->end(), [](int elem) { return elem != 0; }) != zps->end(); + if (has_nonzero && *arg.Type() == "tensor(uint8)") { + quant.SetType(tim::vx::QuantType::ASYMMETRIC_PER_CHANNEL); + } else { + quant.SetType(tim::vx::QuantType::SYMMETRIC_PER_CHANNEL); + } + quant.SetZeroPoints(zps.value()); + } else { + if (*arg.Type() == "tensor(int32)" || zp == 0) { + // set bias quant type + quant.SetType(tim::vx::QuantType::SYMMETRIC_PER_CHANNEL); + } else { + quant.SetType(tim::vx::QuantType::ASYMMETRIC_PER_CHANNEL); + } + quant.SetZeroPoints({zp}); + } + quant.SetScales(scales.value()); + quant.SetChannelDim(shape.size() - 1); + } + spec.SetQuantization(quant); + } + + std::shared_ptr tensor; + if (attribute == + tim::vx::TensorAttribute::CONSTANT) { // create const tensor + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_viewer_.GetConstantInitializer(arg.Name(), true); + std::shared_ptr unpackedTensor = + vsi::npu::util::UnpackTensor(&arg, *tensor_proto); + + const void* valueAddr = + reinterpret_cast(unpackedTensor.get()); + tensor = graph->CreateTensor(spec, valueAddr); + + } else { + tensor = graph->CreateTensor(spec); + } + for (auto& input : graph_inputs_) { + if (input->name == arg.Name()) { + input->tensor = tensor; + input->shape = vsi::npu::util::GetTensorShape(arg); + break; + } + } + for (auto& output : graph_outputs_) { + if (output->name == arg.Name()) { + output->tensor = tensor; + output->shape = utils::GetTensorShapeFromTensorShapeProto(*arg.Shape()); + break; + } + } + tensors_.insert({arg.Name(), tensor}); + return tensor; +} + +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h new file mode 100644 index 0000000000000..bd0f377b820b0 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -0,0 +1,116 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ + +#pragma once +#include +#include +#include +#include +#include +#include "builders/op_builder.h" +#include "tim/vx/context.h" +#include "tim/vx/graph.h" +#include "tim/vx/tensor.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +struct GraphIOInfo { + std::string name; + bool is_initializer; + std::shared_ptr tensor; + TensorShape shape; +}; + +struct NodeIOInfo { + std::shared_ptr op_; + std::vector input_names_; + std::vector output_names_; +}; + +class GraphEP { + public: + explicit GraphEP(const GraphViewer& graph_viewer); + ~GraphEP() {} + + bool Prepare(); + + static bool SupportedOp(const onnxruntime::GraphViewer& graph_viewer, + const NodeUnit& node_unit); + + // If a node is supported by VSINPU in a partition node group + // `node_outputs_in_group` is the set of the output names of the nodes added to this group so far + static bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer); + + const NodeUnit& GetNodeUnit(const Node* node) const; + + bool& GetCompiled() { return compiled_; } + std::shared_ptr& GetGraph() { return graph_; } + std::vector>& GetOps() { return ops_; } + std::map>& GetTensors() { + return tensors_; + } + + std::vector>& GetGraphInputs() { + return graph_inputs_; + } + + std::vector>& GetGraphOutputs() { + return graph_outputs_; + } + + void UpdateTensorMap(const std::string& name, const std::shared_ptr& dst_tensor); + + std::shared_ptr ConstructNodeIO(const std::shared_ptr& op, std::vector input_arg, std::vector output_arg); + + bool BindTensors(const std::shared_ptr& nodeio_info); + + std::shared_ptr MapTIMVXTensor( + std::shared_ptr& graph, const NodeUnitIODef nudef, + const NodeUnit& nodeunit, + const GraphViewer* graph_viewer, tim::vx::TensorAttribute attribute); + + private: + std::shared_ptr context_; + std::shared_ptr graph_; + std::map> tensors_; + std::vector> ops_; + std::vector> graph_inputs_; + std::vector> graph_outputs_; + + // Contains all quantized operators' input and the NodeUnit(s) using the input + // In the form of {input_name, [NodeUnit(s) using the input]} + std::unordered_map> all_quantized_op_inputs_; + const GraphViewer& graph_viewer_; + + // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is + // valid throughout the lifetime of the ModelBuilder + std::vector> node_unit_holder_; + std::unordered_map node_unit_map_; + bool compiled_; +}; +} // namespace npu + +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc new file mode 100644 index 0000000000000..7444dcfec09a2 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -0,0 +1,277 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include +#include +#include +#include "core/framework/compute_capability.h" +#include "core/providers/vsinpu/vsinpu_execution_provider.h" +#include "core/providers/vsinpu/vsinpu_ep_graph.h" +#include "core/providers/vsinpu/builders/op_builder.h" +#include "core/providers/vsinpu/builders/op_builder_factory.h" +#include "core/providers/vsinpu/vsinpu_util.h" +#include "core/framework/kernel_registry.h" +#include "core/framework/node_unit.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/partitioning_utils.h" + +namespace onnxruntime { +VSINPUExecutionProvider::VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kVSINPUExecutionProvider}, + device_id_(info.device_id) { + AllocatorCreationInfo default_memory_info{ + [](int) { + return std::make_unique( + OrtMemoryInfo("VSINPU", OrtAllocatorType::OrtDeviceAllocator)); + }}; + + CreateAllocator(default_memory_info); + + AllocatorCreationInfo cpu_memory_info{ + [](int) { + return std::make_unique( + OrtMemoryInfo("VSINPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); + }}; + + CreateAllocator(cpu_memory_info); +} + +VSINPUExecutionProvider::~VSINPUExecutionProvider() {} + +std::vector> +VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const { + std::vector> result; + + if (graph_viewer.IsSubgraph()) { + return result; + } + + for (const auto& tensor : graph_viewer.GetAllInitializedTensors()) { + if (tensor.second->has_data_location()) { + LOGS_DEFAULT(VERBOSE) << "location:" << tensor.second->data_location(); + if (tensor.second->data_location() == + ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + LOGS_DEFAULT(WARNING) << "VSINPU: Initializers with external data location are not " + "currently supported"; + return result; + } + } + } + // Get all the NodeUnits in the graph_viewer + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + + // This holds the result of whether a NodeUnit is supported or not, + // to prevent nodes in a NodeUnit to be checked for multiple times + std::unordered_map node_unit_supported_result; + node_unit_supported_result.reserve(node_unit_holder.size()); + std::unordered_set node_outputs_in_current_group{}; + + const auto is_node_supported = [&](const Node& node) -> bool { + const NodeUnit* node_unit = node_unit_map.at(&node); + bool supported = false; + + // If we have visited one of the nodes in the node_unit, use the result directly + const auto it = node_unit_supported_result.find(node_unit); + if (it != node_unit_supported_result.cend()) { + supported = it->second; + } else { + // We only check the target node of the node unit + supported = vsi::npu::GraphEP::IsNodeSupportedInGroup(*node_unit, graph_viewer); + node_unit_supported_result[node_unit] = supported; + } + + LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + << "] Operator type: [" << node.OpType() + << "] index: [" << node.Index() + << "] name: [" << node.Name() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + + if (supported) { + // We want to save all the output names of nodes in the current group for easy query + for (const auto* output : node.OutputDefs()) { + node_outputs_in_current_group.insert(output->Name()); + } + } + return supported; + }; + + const auto on_group_closed = [&](const std::vector& group) -> bool { + // reset per-partition node group tracking + node_outputs_in_current_group.clear(); + return true; + }; + + const auto gen_metadef_name = [&]() { + static size_t group_counter = 0; + return "VSINPU_" + std::to_string(++group_counter); + }; + result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "VSINPU", kVSINPUExecutionProvider, &node_unit_map); + std::for_each(result.begin(), result.end(), [&graph_viewer](auto& capability) { + if (capability && capability->sub_graph && capability->sub_graph->GetMetaDef()) { + const auto* meta_def = capability->sub_graph->GetMetaDef(); + bool has_any_non_constant_inputs = std::any_of(meta_def->inputs.begin(), meta_def->inputs.end(), [&graph_viewer](const auto& input) { + return !graph_viewer.IsConstantInitializer(input, true); + }); + + // ALL inputs are constant + if (!has_any_non_constant_inputs) { + capability.reset(); + } + } + }); + + const auto num_of_partitions = result.size(); + const auto num_of_supported_nodes = std::accumulate( + result.begin(), result.end(), size_t{0}, + [](const auto& acc, const auto& partition) -> size_t { + return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0); + }); + + const auto summary_msg = MakeString( + "VSINPUExecutionProvider::GetCapability,", + " number of partitions supported by VSINPU: ", num_of_partitions, + "; number of nodes in the graph: ", graph_viewer.NumberOfNodes(), + "; number of nodes supported by VSINPU: ", num_of_supported_nodes); + + // If the graph is partitioned in multiple subgraphs, and this may impact performance, + // we want to give users a summary message at warning level. + if (num_of_partitions > 1) { + LOGS_DEFAULT(WARNING) << summary_msg; + } else { + LOGS_DEFAULT(INFO) << summary_msg; + } + + return result; +} + +Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, + OrtKernelContext* context) { + Ort::KernelContext ctx(context); + size_t num_in = ctx.GetInputCount(); + const size_t num_inputs = graph_ep->GetGraphInputs().size(); + + for (size_t i = 0, j = 0; i < num_inputs; i++) { + if (!graph_ep->GetGraphInputs()[i]->is_initializer) { + const auto onnx_input_tensor = ctx.GetInput(i); + const auto tensor_info = onnx_input_tensor.GetTensorTypeAndShapeInfo(); + + auto origin_tensor = graph_ep->GetGraphInputs()[i]->tensor; + origin_tensor->CopyDataToTensor(onnx_input_tensor.GetTensorRawData(), vsi::npu::util::GetTensorBytes(tensor_info)); + j++; + } + } + + if (!graph_ep->GetGraph()->Run()) { + LOGS_DEFAULT(ERROR) << "Failed to run graph."; + } + for (size_t i = 0; i < ctx.GetOutputCount(); i++) { + auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor; + auto out_shape = graph_ep->GetGraphOutputs()[i]->shape.GetDims(); + auto onnx_output_tensor = + ctx.GetOutput(i, out_shape.data(), out_shape.size()); + timvx_tensor->CopyDataFromTensor(const_cast(onnx_output_tensor.GetTensorRawData())); + } + + return Status::OK(); +} + +Status VSINPUExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; + std::shared_ptr graph_ep = std::make_shared(graph_viewer); + + for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) { + LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" + << graph_viewer.IsInitializedTensor(tensor->Name()); + auto input = std::make_shared(); + input->name = tensor->Name(); + input->is_initializer = graph_viewer.IsConstantInitializer(tensor->Name(), true); + graph_ep->GetGraphInputs().push_back(input); + } + for (auto tensor : graph_viewer.GetOutputs()) { + LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); + auto output = std::make_shared(); + output->name = tensor->Name(); + output->is_initializer = false; + graph_ep->GetGraphOutputs().push_back(output); + } + + auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& node_index : node_indices) { + const auto node = graph_viewer.GetNode(node_index); + const NodeUnit& node_unit = graph_ep->GetNodeUnit(node); + + // Only add op when we hit the target node + if (node != &node_unit.GetNode()) { + continue; + } + LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]"; + vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit); + } + + LOGS_DEFAULT(INFO) << "Verifying graph"; + graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile(); + if (!graph_ep->GetCompiled()) { + LOGS_DEFAULT(ERROR) << "Failed to verify graph."; + } else { + LOGS_DEFAULT(INFO) << "Graph has been verified successfully."; + } + + NodeComputeInfo compute_info; + compute_info.create_state_func = [graph_ep](ComputeContext* /*context*/, + FunctionState* state) { + *state = graph_ep.get(); + return 0; + }; + + compute_info.compute_func = + [graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */, + OrtKernelContext* context) { + std::lock_guard lock(this->GetMutex()); + Status res = ComputeStateFunc(graph_ep.get(), context); + return res; + }; + + compute_info.release_state_func = [](FunctionState /*state*/) {}; + + node_compute_funcs.push_back(compute_info); + } + + return Status::OK(); +} + +std::shared_ptr VSINPUExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr kernel_registry = std::make_shared(); + return kernel_registry; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h new file mode 100644 index 0000000000000..44318c332fdd0 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -0,0 +1,53 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include "core/framework/execution_provider.h" +#include "core/session/abi_session_options_impl.h" + +namespace onnxruntime { +struct VSINPUExecutionProviderInfo { + int device_id{0}; +}; + +class VSINPUExecutionProvider : public IExecutionProvider { + public: + explicit VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info); + virtual ~VSINPUExecutionProvider(); + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& kernel_lookup) const override; + std::shared_ptr GetKernelRegistry() const override; + Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; + OrtMutex& GetMutex() { return mutex_; } + + private: + int device_id_; + OrtMutex mutex_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc new file mode 100644 index 0000000000000..5f2f961d95c09 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.cc @@ -0,0 +1,59 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#include "core/framework/compute_capability.h" +#include "core/providers/vsinpu/vsinpu_provider_factory.h" +#include "core/providers/vsinpu/vsinpu_provider_factory_creator.h" +#include "core/providers/vsinpu/vsinpu_execution_provider.h" + +namespace onnxruntime { + +struct VSINPUProviderFactory : IExecutionProviderFactory { + VSINPUProviderFactory() {} + ~VSINPUProviderFactory() override {} + + std::unique_ptr CreateProvider() override; +}; + +std::unique_ptr VSINPUProviderFactory::CreateProvider() { + onnxruntime::VSINPUExecutionProviderInfo info; + return std::make_unique(info); +} + +std::shared_ptr CreateExecutionProviderFactory_VSINPU() { + return std::make_shared(); +} + +std::shared_ptr +VSINPUProviderFactoryCreator::Create() { + return std::make_shared(); +} + +} // namespace onnxruntime + +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_VSINPU, + _In_ OrtSessionOptions* options) { + options->provider_factories.push_back( + onnxruntime::VSINPUProviderFactoryCreator::Create()); + return nullptr; +} diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h new file mode 100644 index 0000000000000..e69185c0df816 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory_creator.h @@ -0,0 +1,34 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once + +#include + +#include "core/providers/providers.h" + +namespace onnxruntime { +struct VSINPUProviderFactoryCreator { + static std::shared_ptr Create(); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc new file mode 100644 index 0000000000000..8008ec1f436a4 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc @@ -0,0 +1,502 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ + +#include +#include +#include +#include +#include "core/providers/vsinpu/vsinpu_util.h" + +#include "core/optimizer/initializer.h" +#include "core/providers/shared/utils/utils.h" +namespace onnxruntime { + +template +struct shared_array_deletor { + void operator()(T const* ptr) { delete[] ptr; } +}; +namespace vsi { +namespace npu { +namespace util { +tim::vx::DataType OnnxDtypeToTIMVXDtype(const int32_t dtype) { + switch (dtype) { + case onnx::TensorProto_DataType_FLOAT: + return tim::vx::DataType::FLOAT32; + case onnx::TensorProto_DataType_FLOAT16: + return tim::vx::DataType::FLOAT16; + case onnx::TensorProto_DataType_INT8: + return tim::vx::DataType::INT8; + case onnx::TensorProto_DataType_UINT8: + return tim::vx::DataType::UINT8; + case onnx::TensorProto_DataType_INT32: + return tim::vx::DataType::INT32; + case onnx::TensorProto_DataType_INT16: + return tim::vx::DataType::INT16; + case onnx::TensorProto_DataType_UINT16: + return tim::vx::DataType::UINT16; + case onnx::TensorProto_DataType_BOOL: + return tim::vx::DataType::BOOL8; + default: + LOGS_DEFAULT(WARNING) << "Unsupported data type: " << dtype; + break; + } + return tim::vx::DataType::FLOAT32; +} + +tim::vx::DataType OnnxDtypeToTIMVXDtype(const ONNX_NAMESPACE::DataType type) { + static const std::map type_table = { + {"tensor(float)", tim::vx::DataType::FLOAT32}, + {"tensor(float16)", tim::vx::DataType::FLOAT16}, + {"tensor(int8)", tim::vx::DataType::INT8}, + {"tensor(uint8)", tim::vx::DataType::UINT8}, + {"tensor(int32)", tim::vx::DataType::INT32}, + {"tensor(int16)", tim::vx::DataType::INT16}, + {"tensor(uint16)", tim::vx::DataType::UINT16}, + {"tensor(int64)", tim::vx::DataType::INT64}, + {"tensor(bool)", tim::vx::DataType::BOOL8}, + }; + auto search = type_table.find(*type); + if (search != type_table.end()) { + return search->second; + } + LOGS_DEFAULT(WARNING) << "Unsupported data type: " << *type; + return tim::vx::DataType::FLOAT32; +} + +tim::vx::ShapeType OnnxShapeToTIMVXShape(const onnxruntime::TensorShape& ts) { + tim::vx::ShapeType timvx_shape(ts.NumDimensions()); + if (ts.NumDimensions() == 0) { + timvx_shape.push_back(1); + } else { + for (size_t i = 0; i < ts.NumDimensions(); i++) { + timvx_shape[i] = ts.GetDims()[i]; + } + } + return timvx_shape; +} + +std::string PrintNode(const onnxruntime::NodeArg& node_arg) { + auto shape = node_arg.Shape(); + if (shape == nullptr) { + return ""; + } + std::string s = node_arg.Name() + ":<"; + if (shape->dim_size() == 0) { + s += "1>, is a scalar"; + return s; + } + for (int i = 0; i < shape->dim_size(); i++) { + auto dim = shape->dim(i); + std::string s1; + std::stringstream ss; + ss << dim.dim_value(); + ss >> s1; + s += s1; + if (i < shape->dim_size() - 1) { + s += ","; + } else { + s += ">"; + } + } + return s; +} + +std::string PrintNode(const std::vector shape) { + if (shape.size() == 0) { + return ""; + } + std::string s = "<"; + for (std::size_t i = 0; i < shape.size(); i++) { + auto dim = shape[i]; + std::string s1; + std::stringstream ss; + ss << dim; + ss >> s1; + s += s1; + if (i < shape.size() - 1) { + s += ","; + } else { + s += ">"; + } + } + return s; +} + +size_t GetTensorElementSize(const ONNXTensorElementDataType type) { + switch (type) { + case onnx::TensorProto_DataType_INT64: + return 8; + case onnx::TensorProto_DataType_FLOAT: + case onnx::TensorProto_DataType_INT32: + return 4; + case onnx::TensorProto_DataType_FLOAT16: + case onnx::TensorProto_DataType_INT16: + case onnx::TensorProto_DataType_UINT16: + return 2; + case onnx::TensorProto_DataType_INT8: + case onnx::TensorProto_DataType_UINT8: + case onnx::TensorProto_DataType_BOOL: + return 1; + default: + break; + } + return 0; +} + +size_t GetTensorBytes(const Ort::TensorTypeAndShapeInfo& info) { + return info.GetElementCount() * GetTensorElementSize(info.GetElementType()); +} + +TensorShape GetTensorShape(const onnxruntime::NodeArg& node_arg) { + auto shape_proto = node_arg.Shape(); + std::vector dims; + if (shape_proto != nullptr) { + for (int i = 0; i < shape_proto->dim_size(); i++) { + auto dim = shape_proto->dim(i); + dims.push_back(dim.dim_value()); + } + } + if (dims.size() == 0) { + dims.push_back(1); + } + TensorShape ts(dims); + return ts; +} + +std::shared_ptr UnpackTensor( + const NodeArg* node_arg, const ONNX_NAMESPACE::TensorProto& initializer) { + std::shared_ptr unpackedTensor; + auto shape = GetTensorShape(*node_arg); + size_t elementCount = shape.Size(); + +#define CASE_PROTO(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: { \ + size_t tensorByteSize = elementCount * sizeof(Y); \ + unpackedTensor.reset(new uint8_t[tensorByteSize], \ + shared_array_deletor()); \ + auto status = onnxruntime::utils::UnpackTensor( \ + initializer, \ + initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ + initializer.has_raw_data() ? initializer.raw_data().size() : 0, \ + reinterpret_cast(unpackedTensor.get()), elementCount); \ + if (!status.IsOK()) { \ + LOGS_DEFAULT(ERROR) << "Unpack tensor data failed."; \ + } \ + break; \ + } + switch (initializer.data_type()) { + CASE_PROTO(FLOAT, float); + CASE_PROTO(DOUBLE, double); + CASE_PROTO(BOOL, bool); + CASE_PROTO(INT8, int8_t); + CASE_PROTO(INT16, int16_t); + CASE_PROTO(INT32, int32_t); + CASE_PROTO(INT64, int64_t); + CASE_PROTO(UINT8, uint8_t); + CASE_PROTO(UINT16, uint16_t); + CASE_PROTO(UINT32, uint32_t); + CASE_PROTO(FLOAT16, onnxruntime::MLFloat16); + default: + return nullptr; + } + + return unpackedTensor; +} + +tim::vx::PadType GetPadType(const std::string type) { + static const std::map type_table = { + {"NOTSET", tim::vx::PadType::AUTO}, + {"SAME_UPPER", tim::vx::PadType::SAME}, + {"SAME_LOWER", tim::vx::PadType::SAME}, + {"VALID", tim::vx::PadType::VALID}, + }; + auto search = type_table.find(type); + if (search != type_table.end()) { + return search->second; + } + return tim::vx::PadType::NONE; +} + +int32_t ReverseAxis(int32_t origin_axis, int32_t length) { + int32_t axis = 0; + if (origin_axis < 0) { + origin_axis += length; + } + axis = length - origin_axis - 1; + return axis; +} + +std::vector ReverseAxis(std::vector origin_axes, int32_t length) { + std::vector axes; + for (int32_t& axis : origin_axes) { + if (axis < 0) { + axis += length; + } + axes.push_back(length - axis - 1); + } + std::sort(axes.begin(), axes.end()); + return axes; +} + +bool IsTypeSupported(const NodeArg* node_arg) { + const auto* type_proto = node_arg->TypeAsProto(); + if (!type_proto) { + return false; + } + + switch (type_proto->tensor_type().elem_type()) { + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + return true; + default: + return false; + } +} + +QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) { + const auto& op_type = node_unit.OpType(); + if (node_unit.UnitType() == NodeUnit::Type::SingleNode) { + if (op_type == "DequantizeLinear") + return QuantizedOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QuantizedOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QuantizedOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QuantizedOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QuantizedOpType::QLinearAdd; + else if (op_type == "QLinearMul") + return QuantizedOpType::QLinearMul; + else if (op_type == "QLinearSigmoid") + return QuantizedOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QuantizedOpType::QLinearAveragePool; + } else if (node_unit.UnitType() == NodeUnit::Type::QDQGroup) { + if (op_type == "Conv") + return QuantizedOpType::QDQConv; + else if (op_type == "Resize") + return QuantizedOpType::QDQResize; + else if (op_type == "AveragePool") + return QuantizedOpType::QDQAveragePool; + else if (op_type == "Add") + return QuantizedOpType::QDQAdd; + else if (op_type == "Mul") + return QuantizedOpType::QDQMul; + else if (op_type == "Transpose") + return QuantizedOpType::QDQTranspose; + else if (op_type == "Reshape") + return QuantizedOpType::QDQReshape; + else if (op_type == "Softmax") + return QuantizedOpType::QDQSoftmax; + else if (op_type == "Concat") + return QuantizedOpType::QDQConcat; + else if (op_type == "Gemm") + return QuantizedOpType::QDQGemm; + else if (op_type == "MatMul") + return QuantizedOpType::QDQMatMul; + } + return QuantizedOpType::Unknown; +} + +ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers) { + NodeAttrHelper helper(node_unit); + const auto group = helper.Get("group", 1); + + const auto& weight = node_unit.Inputs()[1].node_arg.Name(); + const auto& weight_tensor = *initializers.at(weight); + + // For ONNX we only have 1 conv ops + // For VSINPU we have 3 + // Input is (W, H, C, N) + // group == 1, --> regular conv + // group != 1 && weight is (kW, kH, group, M), --> depthwise conv + // group != 1 && weight is (kW, kH, C/group, M), --> grouped conv + if (group == 1) + return ConvType::Regular; + else if ((weight_tensor.dims()[1] == group)) + return ConvType::Depthwise; + else + return ConvType::Grouped; +} + +bool IsQuantizedConv(QuantizedOpType quant_op_type) { + return (quant_op_type == QuantizedOpType::QLinearConv) || + (quant_op_type == QuantizedOpType::QDQConv); +} + +bool IsQuantizedPool(QuantizedOpType quant_op_type) { + return (quant_op_type == QuantizedOpType::QLinearAveragePool) || + (quant_op_type == QuantizedOpType::QDQAveragePool); +} + +bool IsQuantizedGemm(QuantizedOpType quant_op_type) { + return (quant_op_type == QuantizedOpType::QLinearMatMul) || + (quant_op_type == QuantizedOpType::QDQGemm) || + (quant_op_type == QuantizedOpType::QDQMatMul); +} + +bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type) { + return quant_op_type == QuantizedOpType::QLinearMatMul || + quant_op_type == QuantizedOpType::QLinearAdd || + quant_op_type == QuantizedOpType::QLinearMul || + quant_op_type == QuantizedOpType::QDQAdd || + quant_op_type == QuantizedOpType::QDQMul || + quant_op_type == QuantizedOpType::QDQGemm || + quant_op_type == QuantizedOpType::QDQMatMul || + IsQuantizedConv(quant_op_type); +} + +bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit) { + auto quant_op_type = GetQuantizedOpType(node_unit); + int32_t a_input_type, b_input_type; + if (!IsQuantizedBinaryOp(quant_op_type)) { + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] is not a binary qlinear op"; + return false; + } + + const auto& inputs = node_unit.Inputs(); + if (!GetType(inputs[0].node_arg, a_input_type)) + return false; + if (!GetType(inputs[1].node_arg, b_input_type)) + return false; + + // QlinearConv/MatMul/QDQGemm/QDQMatMul supports u8u8 or u8s8 + // QLinearAdd/QLinearMul only support u8u8 + bool is_quant_conv_or_gemm = IsQuantizedConv(quant_op_type) || IsQuantizedGemm(quant_op_type); + + bool has_valid_qlinear_conv_weight = + (b_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + b_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8); + + bool has_valid_qlinear_conv_input = + (a_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + a_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8); + + if ((is_quant_conv_or_gemm && !has_valid_qlinear_conv_weight) || + (!is_quant_conv_or_gemm && a_input_type != b_input_type)) { + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() + << "] A Input type: [" << a_input_type + << "] B Input type: [" << b_input_type + << "] is not supported for now"; + return false; + } + + return true; +} + +void GetQuantizationScaleAndZeroPoint( + const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + float& scale, int32_t& zero_point, std::optional>& pcq_scales, + std::optional>& pcq_zps) { + scale = 0.0f; + zero_point = 0; + + const auto& quant_param = *io_def.quant_param; + { // get the scale + const auto& name = quant_param.scale.Name(); + Initializer unpacked_tensor(*initializers.at(name), model_path); + scale = unpacked_tensor.DataAsSpan()[0]; + + // per channel quantized handling + if (!unpacked_tensor.dims().empty() && unpacked_tensor.dims()[0] != 0 && unpacked_tensor.dims()[0] != 1) { + auto scales = unpacked_tensor.DataAsSpan(); + std::vector scales_vec(scales.begin(), scales.end()); + pcq_scales = onnxruntime::make_optional(std::move(scales_vec)); + } + } + + if (quant_param.zero_point) { // get the zero point if it exists + const auto& name = quant_param.zero_point->Name(); + Initializer unpacked_tensor(*initializers.at(name), model_path); + bool is_i8_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT8; + // some qdq conv bias is int32 quantized + bool is_int32_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT32; + zero_point = is_i8_zp ? static_cast(unpacked_tensor.DataAsSpan()[0]) : is_int32_zp ? static_cast(unpacked_tensor.DataAsSpan()[0]) + : static_cast(unpacked_tensor.DataAsByteSpan()[0]); + + // per channel quantized handling + if (!unpacked_tensor.dims().empty() && unpacked_tensor.dims()[0] != 0 && unpacked_tensor.dims()[0] != 1) { + auto type = unpacked_tensor.data_type(); + if (is_i8_zp) { + auto zps = unpacked_tensor.DataAsSpan(); + std::vector zps_vec(zps.begin(), zps.end()); + pcq_zps = onnxruntime::make_optional(std::move(zps_vec)); + } else if (is_int32_zp) { + auto zps = unpacked_tensor.DataAsByteSpan(); + std::vector zps_vec(zps.begin(), zps.end()); + pcq_zps = onnxruntime::make_optional(std::move(zps_vec)); + } else { + auto zps = unpacked_tensor.DataAsSpan(); + std::vector zps_vec(zps.begin(), zps.end()); + pcq_zps = onnxruntime::make_optional(std::move(zps_vec)); + } + } + } +} + +static bool IsInternalQuantizedNodeUnit(const NodeUnit& node_unit) { + // First, ignore QDQ NodeUnit which is not internal quantized node + if (node_unit.UnitType() == NodeUnit::Type::QDQGroup) + return false; + + // These operators can use uint8 input without specific QLinear version of it + // However, the mode has to be internal to the graph/partition (they cannot consume graph inputs) + static const std::unordered_set internal_quantized_op_types = { + "Transpose", + "Resize", + "Concat", + "MaxPool", + }; + + const auto& node = node_unit.GetNode(); + if (!Contains(internal_quantized_op_types, node.OpType())) + return false; + + int32_t input_type; + ORT_ENFORCE(GetType(*node.InputDefs()[0], input_type)); + + return input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; +} + +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + LOGS_DEFAULT(WARNING) << "NodeArg [" << node_arg.Name() << "] has no input type"; + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} +} // namespace util +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.h b/onnxruntime/core/providers/vsinpu/vsinpu_util.h new file mode 100644 index 0000000000000..9ec580bf02e77 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.h @@ -0,0 +1,131 @@ +/**************************************************************************** + * + * Copyright (c) 2023 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ + +#pragma once +#include +#include +#include +#include "core/framework/op_kernel.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensorprotoutils.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/node_unit.h" +#include "tim/vx/tensor.h" +#include "tim/vx/types.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { +namespace util { + +tim::vx::DataType OnnxDtypeToTIMVXDtype(const int32_t dtype); + +tim::vx::DataType OnnxDtypeToTIMVXDtype(const ONNX_NAMESPACE::DataType type); + +tim::vx::ShapeType OnnxShapeToTIMVXShape(const onnxruntime::TensorShape& ts); + +std::string PrintNode(const onnxruntime::NodeArg& node_arg); + +std::string PrintNode(const std::vector shape); + +size_t GetTensorElementSize(const ONNXTensorElementDataType type); + +size_t GetTensorBytes(const Ort::TensorTypeAndShapeInfo& info); + +TensorShape GetTensorShape(const onnxruntime::NodeArg& node_arg); + +std::shared_ptr UnpackTensor( + const NodeArg* node, const ONNX_NAMESPACE::TensorProto& initializer); + +tim::vx::PadType GetPadType(const std::string type); + +int32_t ReverseAxis(int32_t origin_axis, int32_t length); + +std::vector ReverseAxis(std::vector origin_axes, int32_t length); + +bool IsTypeSupported(const NodeArg* node_arg); + +enum class QuantizedOpType : uint8_t { + Unknown, // Unknown or not a quantized NodeUnit + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + // Not yet supported + // QLinearReduceMean, + QDQConv, + QDQResize, + QDQAveragePool, + QDQAdd, + QDQMul, + QDQTranspose, + QDQReshape, + QDQSoftmax, + QDQConcat, + QDQGemm, + QDQMatMul, + // TODO(cfy) :Add other QDQ NodeUnit types +}; + +enum class ConvType : uint8_t { + Regular, + Depthwise, + Grouped, +}; +QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit); + +ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers); + +// If this is a quantized Conv (QLinearConv or QDQConv) +bool IsQuantizedConv(QuantizedOpType quant_op_type); + +// If this is a quantized Pool (QLinearAveragePool or QDQAveragePool) +bool IsQuantizedPool(QuantizedOpType quant_op_type); + +// If this is a quantized Gemm (QLinearMatMul or QDQMatMul/QDQGemm) +bool IsQuantizedGemm(QuantizedOpType quant_op_type); + +// This quantized op is an operator or qdq node unit takes 2 inputs and produces 1 output +// Such as QLinearConv, QLinearMatMul, QLinearAdd, QDQConv,... +bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type); + +// Check if a qlinear binary op has valid inputs, Qlinear[Conv/MatMul/Add] +bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit); + +void GetQuantizationScaleAndZeroPoint( + const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + float& scale, int32_t& zero_point, + std::optional>& pcq_scales, + std::optional>& pcq_zps); + +bool GetType(const NodeArg& node_arg, int32_t& type); + +} // namespace util +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 6d3e9c2cb7865..3319fdd34646b 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -796,7 +796,7 @@ void LoadTests(const std::vector>& input_paths auto test_case_dir = model_info->GetDir(); auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir.native(); -#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) +#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) && !defined(USE_VSINPU) // to skip some models like *-int8 or *-qdq if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0356bf5218cc2..fc29756a1ff98 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -44,7 +44,7 @@ void usage() { "\t-r [repeat]: Specifies the number of times to repeat\n" "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" - "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', " + "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', 'vsinpu'" "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'nnapi', 'qnn', 'snpe' or 'coreml'. " "Default: 'cpu'.\n" "\t-p: Pause after launch, can attach debugger and continue\n" @@ -169,6 +169,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_mem_pattern = true; bool enable_qnn = false; bool enable_nnapi = false; + bool enable_vsinpu = false; bool enable_coreml = false; bool enable_snpe = false; bool enable_dml = false; @@ -248,6 +249,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_qnn = true; } else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) { enable_nnapi = true; + } else if (!CompareCString(optarg, ORT_TSTR("vsinpu"))) { + enable_vsinpu = true; } else if (!CompareCString(optarg, ORT_TSTR("coreml"))) { enable_coreml = true; } else if (!CompareCString(optarg, ORT_TSTR("snpe"))) { @@ -561,6 +564,14 @@ int real_main(int argc, char* argv[], Ort::Env& env) { #else fprintf(stderr, "NNAPI is not supported in this build"); return -1; +#endif + } + if (enable_vsinpu) { +#ifdef USE_VSINPU + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_VSINPU(sf)); +#else + fprintf(stderr, "VSINPU is not supported in this build"); + return -1; #endif } if (enable_coreml) { diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 175079d8197bf..b7c99fa66a1ea 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -261,6 +261,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.machine_config.provider_type_name = onnxruntime::kSnpeExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) { test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("vsinpu"))) { + test_config.machine_config.provider_type_name = onnxruntime::kVSINPUExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("coreml"))) { test_config.machine_config.provider_type_name = onnxruntime::kCoreMLExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("dml"))) { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 1485a4456d326..ff782da35cbe6 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -397,6 +397,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, nnapi_flags)); #else ORT_THROW("NNAPI is not supported in this build\n"); +#endif + } else if (provider_name_ == onnxruntime::kVSINPUExecutionProvider) { +#ifdef USE_VSINPU + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_VSINPU(session_options)); +#else + ORT_THROW("VSINPU is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { #ifdef __APPLE__ diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 8d84c689cd23e..1db8616c85daa 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -428,6 +428,7 @@ bool SetEpsForAllNodes(Graph& graph, if (provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kTensorrtExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || + provider_type == onnxruntime::kVSINPUExecutionProvider || provider_type == onnxruntime::kCoreMLExecutionProvider || provider_type == onnxruntime::kDnnlExecutionProvider || provider_type == onnxruntime::kQnnExecutionProvider || @@ -649,6 +650,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, kAclExecutionProvider, kArmNNExecutionProvider, kNnapiExecutionProvider, + kVSINPUExecutionProvider, kRocmExecutionProvider, kCoreMLExecutionProvider, kCoreMLExecutionProviderMLProgram, @@ -688,6 +690,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultTensorrtExecutionProvider(); else if (provider_type == onnxruntime::kNnapiExecutionProvider) execution_provider = DefaultNnapiExecutionProvider(); + else if (provider_type == onnxruntime::kVSINPUExecutionProvider) + execution_provider = DefaultVSINPUExecutionProvider(); else if (provider_type == onnxruntime::kRknpuExecutionProvider) execution_provider = DefaultRknpuExecutionProvider(); else if (provider_type == onnxruntime::kAclExecutionProvider) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index dcb592a4a254e..cb9887314eb66 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -30,6 +30,10 @@ #include "core/providers/nnapi/nnapi_provider_factory.h" #endif +#ifdef USE_VSINPU +#include "core/providers/vsinpu/vsinpu_provider_factory.h" +#endif + #ifdef USE_RKNPU #include "core/providers/rknpu/rknpu_provider_factory.h" #endif @@ -238,6 +242,11 @@ TEST_P(ModelTest, Run) { ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); } #endif +#ifdef USE_VSINPU + else if (provider_name == "vsinpu") { + ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); + } +#endif #ifdef USE_RKNPU else if (provider_name == "rknpu") { ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); @@ -406,6 +415,9 @@ static constexpr ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl"); #if defined(USE_NNAPI) && defined(__ANDROID__) static constexpr ORT_STRING_VIEW provider_name_nnapi = ORT_TSTR("nnapi"); #endif +#ifdef USE_VSINPU +static ORT_STRING_VIEW provider_name_vsinpu = ORT_TSTR("vsinpu"); +#endif #ifdef USE_RKNPU static constexpr ORT_STRING_VIEW provider_name_rknpu = ORT_TSTR("rknpu"); #endif @@ -447,6 +459,9 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #if defined(USE_NNAPI) && defined(__ANDROID__) provider_names[provider_name_nnapi] = {opset7, opset8, opset9, opset10, opset11, opset12, opset13, opset14, opset15, opset16, opset17, opset18}; #endif +#ifdef USE_VSINPU + provider_names[provider_name_vsinpu] = {}; +#endif #ifdef USE_RKNPU provider_names[provider_name_rknpu] = {}; #endif diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index f15ac100f4e3f..312aa86277994 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -190,6 +190,14 @@ std::unique_ptr DefaultNnapiExecutionProvider() { #endif } +std::unique_ptr DefaultVSINPUExecutionProvider() { +#if defined(USE_VSINPU) + return VSINPUProviderFactoryCreator::Create()->CreateProvider(); +#else + return nullptr; +#endif +} + std::unique_ptr DefaultRknpuExecutionProvider() { #ifdef USE_RKNPU return RknpuProviderFactoryCreator::Create()->CreateProvider(); diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index ae8e89c386994..606dfc068d399 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -20,6 +20,7 @@ std::shared_ptr CreateExecutionProviderFactory_MIGrap std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); // std::shared_ptr CreateExecutionProviderFactory_Tvm(const char*); +std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params); @@ -50,6 +51,7 @@ std::unique_ptr MIGraphXExecutionProviderWithOptions(const O std::unique_ptr OpenVINOExecutionProviderWithOptions(const OrtOpenVINOProviderOptions* params); std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); +std::unique_ptr DefaultVSINPUExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); std::unique_ptr DefaultAclExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); diff --git a/onnxruntime/test/util/include/providers.h b/onnxruntime/test/util/include/providers.h index aa489e6cd958b..a73b237ae10df 100644 --- a/onnxruntime/test/util/include/providers.h +++ b/onnxruntime/test/util/include/providers.h @@ -16,6 +16,9 @@ #ifdef USE_NNAPI #include "core/providers/nnapi/nnapi_provider_factory.h" #endif +#ifdef USE_VSINPU +#include "core/providers/vsinpu/vsinpu_provider_factory.h" +#endif #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f431f471c4082..b73a17db3ce13 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -561,6 +561,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--use_snpe", action="store_true", help="Build with SNPE support.") parser.add_argument("--snpe_root", help="Path to SNPE SDK root.") parser.add_argument("--use_nnapi", action="store_true", help="Build with NNAPI support.") + parser.add_argument("--use_vsinpu", action="store_true", help="Build with VSINPU support.") parser.add_argument( "--nnapi_min_api", type=int, help="Minimum Android API level to enable NNAPI, should be no less than 27" ) @@ -1020,6 +1021,7 @@ def generate_build_tree( "-Donnxruntime_BUILD_APPLE_FRAMEWORK=" + ("ON" if args.build_apple_framework else "OFF"), "-Donnxruntime_USE_DNNL=" + ("ON" if args.use_dnnl else "OFF"), "-Donnxruntime_USE_NNAPI_BUILTIN=" + ("ON" if args.use_nnapi else "OFF"), + "-Donnxruntime_USE_VSINPU=" + ("ON" if args.use_vsinpu else "OFF"), "-Donnxruntime_USE_RKNPU=" + ("ON" if args.use_rknpu else "OFF"), "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_tvm else "OFF"), "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"),