diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 2e54063598cc..6092b5dac600 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -1164,9 +1164,12 @@ if (onnxruntime_USE_OPENVINO)
   elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.0")
     set(OPENVINO_VERSION "2023.0")
     add_definitions(-DOPENVINO_2023_0=1)
+  elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.1")
+    set(OPENVINO_VERSION "2023.1")
+    add_definitions(-DOPENVINO_2023_1=1)
   elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino")
-    set(OPENVINO_VERSION "2023.0")
-    add_definitions(-DOPENVINO_2023_0=1)
+    set(OPENVINO_VERSION "2023.1")
+    add_definitions(-DOPENVINO_2023_1=1)
   else()
     message(FATAL_ERROR "Unsupported OpenVINO version: ${INTEL_OPENVINO_DIR}")
   endif()
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index d215b12157cf..643da416d8ce 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -593,7 +593,7 @@ typedef struct OrtOpenVINOProviderOptions {
   OrtOpenVINOProviderOptions() : device_type{},
                                  enable_vpu_fast_compile{},
                                  device_id{},
-                                 num_of_threads{},
+                                 num_of_threads{1},
                                  cache_dir{},
                                  context{},
                                  enable_opencl_throttling{},
diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index 5969aaeb44f7..1f5c61c11966 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -7,9 +7,6 @@
 #include
 #include "core/providers/shared_library/provider_api.h"
-
-#include
-
 #include "contexts.h"
 #include "backend_manager.h"
 #include "ibackend.h"
@@ -36,11 +33,11 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node,
                                const logging::Logger& logger) {
   auto prec_str = GetGlobalContext().precision_str;
   if (prec_str == "FP32") {
-    subgraph_context_.precision = InferenceEngine::Precision::FP32;
+    subgraph_context_.precision = "FP32";
   } else if (prec_str == "FP16") {
-    subgraph_context_.precision = InferenceEngine::Precision::FP16;
+    subgraph_context_.precision = "FP16";
   } else if (prec_str == "U8") {
-    subgraph_context_.precision = InferenceEngine::Precision::U8;
+    subgraph_context_.precision = "U8";
   } else {
     throw std::string("Invalid OpenVINO Precision type: " + prec_str);
   }
@@ -78,7 +75,6 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node,
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
     if (GetGlobalContext().device_type.find("CPU") != std::string::npos ||
         GetGlobalContext().device_type.find("GPU") != std::string::npos) {
-      if (GetGlobalContext().enable_dynamic_shapes) {
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. "
                          << "Creating backend Dynamic Shapes";
       try {
@@ -90,7 +86,6 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node,
       }
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
                          << "Backend created for graph " << subgraph_context_.subgraph_name;
-      }
     }
   } else {
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. Initializing backend for graph " << subgraph_context_.subgraph_name;
@@ -257,7 +252,7 @@ void BackendManager::Compute(OrtKernelContext* context) {
   }
 #endif
   bool use_dynamic_backend = true;
-  if (GetGlobalContext().enable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape &&
+  if (subgraph_context_.has_dynamic_input_shape &&
      (GetGlobalContext().device_type.find("CPU") != std::string::npos ||
       GetGlobalContext().device_type.find("GPU") != std::string::npos)) {
     concrete_backend_->Infer(context);
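Note: one visible effect of the `num_of_threads{1}` default and the string-based precision handling is on the application side. A minimal usage sketch of appending the OpenVINO EP through the C++ wrapper API; the field values and the model path are illustrative, not part of this change:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ov_ep_demo");
  Ort::SessionOptions session_options;

  OrtOpenVINOProviderOptions ov_options;  // num_of_threads now default-initializes to 1 instead of 0
  ov_options.device_type = "GPU_FP16";    // illustrative device/precision pair
  ov_options.cache_dir = "";              // empty string disables model caching

  session_options.AppendExecutionProvider_OpenVINO(ov_options);
  Ort::Session session(env, "model.onnx", session_options);  // "model.onnx" is a placeholder path
  return 0;
}
```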
diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc
index c5ebdb413199..d49968cdb7f3 100644
--- a/onnxruntime/core/providers/openvino/backend_utils.cc
+++ b/onnxruntime/core/providers/openvino/backend_utils.cc
@@ -8,8 +8,8 @@
 #include
 #include "ov_interface.h"
-#include
-#include
+#include "openvino/pass/convert_fp32_to_fp16.hpp"
+#include "openvino/pass/constant_folding.hpp"
 #include "core/providers/shared_library/provider_api.h"
 #include "backend_utils.h"
@@ -50,14 +50,14 @@ struct static_cast_int64 {
 std::shared_ptr
 CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context,
               const SubGraphContext& subgraph_context,
-              std::map>& const_outputs_map) {
+              std::map>& const_outputs_map) {
   if (IsCILogEnabled()) {
     std::cout << "CreateNgraphFunc" << std::endl;
   }
   const std::string model = model_proto.SerializeAsString();
   try {
     auto cnn_network = global_context.ie_core.ReadModel(model);
-    if ((subgraph_context.precision == InferenceEngine::Precision::FP16) &&
+    if ((subgraph_context.precision == "FP16") &&
         (global_context.device_type.find("VPUX") == std::string::npos)) {
       // FP16 transformations
       ov::pass::ConvertFP32ToFP16 pass_obj;
@@ -88,7 +88,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext
       size_t index = results.size() - 1;
       for (auto it = results.rbegin(); it != results.rend(); ++it) {
-        if (auto const_node = std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) {
+        if (auto const_node = std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) {
           const_outputs_map[(*it)->get_friendly_name()] = const_node;
           results.erase(results.begin() + index);
         }
@@ -96,12 +96,11 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext
       }
     }
 #ifndef NDEBUG
-#if defined(OPENVINO_2022_3) || (OPENVINO_2023_0)
+#if defined(OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1)
     if (IsDebugEnabled()) {
       std::string name = cnn_network->get_friendly_name();
       ov::pass::Serialize serializer(name + ".xml", name + ".bin");
       serializer.run_on_model(cnn_network);
-      ngraph::plot_graph(cnn_network, name + "_executable" + ".dot");
     }
 #endif
 #endif
@@ -111,31 +110,6 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext
   }
 }
 
-InferenceEngine::Precision ConvertPrecisionONNXToOpenVINO(const ONNX_NAMESPACE::TypeProto& onnx_type) {
-  ONNX_NAMESPACE::DataType type_string = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(onnx_type);
-  if (*type_string == "float" || *type_string == "tensor(float)") {
-    return InferenceEngine::Precision::FP32;
-  } else if (*type_string == "float16" || *type_string == "tensor(float16)") {
-    return InferenceEngine::Precision::FP16;
-  } else if (*type_string == "int32" || *type_string == "tensor(int32)") {
-    return InferenceEngine::Precision::I32;
-  } else if (*type_string == "int16" || *type_string == "tensor(int16)") {
-    return InferenceEngine::Precision::I16;
-  } else if (*type_string == "int8" || *type_string == "tensor(int8)") {
-    return InferenceEngine::Precision::I8;
-  } else if (*type_string == "uint16" || *type_string == "tensor(uint16)") {
-    return InferenceEngine::Precision::U16;
-  } else if (*type_string == "uint8" || *type_string == "tensor(uint8)") {
-    return InferenceEngine::Precision::U8;
-  } else if (*type_string == "bool" || *type_string == "tensor(bool)") {
-    return InferenceEngine::Precision::U8;
-  } else if (*type_string == "int64" || *type_string == "tensor(int64)") {
-    return InferenceEngine::Precision::I32;
-  } else {
-    throw std::string(log_tag + "Unsupported Data type");
-  }
-}
-
 Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context,
                                   size_t batch_size,
                                   OVInferRequestPtr infer_request,
@@ -166,7 +140,7 @@ Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context,
                                   std::string output_name,
                                   std::unordered_map output_names,
-                                  std::shared_ptr node) {
+                                  std::shared_ptr node) {
   // Find position of '/' in the output_name
   int pos = output_name.find("/");
   // Copy the substring from start to pos
@@ -210,25 +184,25 @@ int GetFirstAvailableDevice(GlobalContext& global_context) {
   return i;
 }
 
-void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor) {
+void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor) {
   switch (node->get_element_type()) {
-    case ngraph::element::Type_t::f32: {
+    case ov::element::Type_t::f32: {
       FillOutputHelper(out_tensor, node);
       break;
     }
-    case ngraph::element::Type_t::boolean: {
+    case ov::element::Type_t::boolean: {
       FillOutputHelper(out_tensor, node);
       break;
    }
-    case ngraph::element::Type_t::i32: {
+    case ov::element::Type_t::i32: {
       FillOutputHelper(out_tensor, node);
       break;
    }
-    case ngraph::element::Type_t::i64: {
+    case ov::element::Type_t::i64: {
       FillOutputHelper(out_tensor, node);
       break;
    }
-    case ngraph::element::Type_t::f16: {
+    case ov::element::Type_t::f16: {
       FillOutputHelper(out_tensor, node);
       break;
    }
@@ -237,14 +211,22 @@ void FillOutputsWithConstantData(std::shared_ptr node, Ort::Unowne
   }
 }
 
+#if defined(_MSC_VER)
+#pragma warning(disable : 4127)
+#endif
+
 template
-void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node) {
-  auto const_node = std::dynamic_pointer_cast(node);
+void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node) {
+  auto const_node = std::dynamic_pointer_cast(node);
   auto res = const_node->cast_vector();
   T* tensor_data = out_tensor.GetTensorMutableData();
   std::copy(res.begin(), res.end(), tensor_data);
 }
 
+#if defined(_MSC_VER)
+#pragma warning(default : 4127)
+#endif
+
 void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
                    std::string input_name, Ort::KernelContext& context,
                    const SubGraphContext& subgraph_context) {
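Note: the FP16 handling above now goes through the OpenVINO 2.0 pass headers instead of the removed ngraph includes. A short sketch of that pass API, assuming a model loaded from disk; the paths and the use of `read_model` are illustrative, not the EP's exact call sequence:

```cpp
#include "openvino/openvino.hpp"
#include "openvino/pass/convert_fp32_to_fp16.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/serialize.hpp"

std::shared_ptr<ov::Model> ToFp16(const std::string& onnx_path) {
  ov::Core core;
  std::shared_ptr<ov::Model> model = core.read_model(onnx_path);

  ov::pass::ConvertFP32ToFP16 fp16_pass;
  fp16_pass.run_on_model(model);   // same pattern as pass_obj in CreateOVModel

  ov::pass::ConstantFolding fold_pass;
  fold_pass.run_on_model(model);

  // Debug-only in the EP: dump the transformed model as IR.
  ov::pass::Serialize("debug.xml", "debug.bin").run_on_model(model);
  return model;
}
```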
== "int8" || *type_string == "tensor(int8)") { - return InferenceEngine::Precision::I8; - } else if (*type_string == "uint16" || *type_string == "tensor(uint16)") { - return InferenceEngine::Precision::U16; - } else if (*type_string == "uint8" || *type_string == "tensor(uint8)") { - return InferenceEngine::Precision::U8; - } else if (*type_string == "bool" || *type_string == "tensor(bool)") { - return InferenceEngine::Precision::U8; - } else if (*type_string == "int64" || *type_string == "tensor(int64)") { - return InferenceEngine::Precision::I32; - } else { - throw std::string(log_tag + "Unsupported Data type"); - } -} - Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context, size_t batch_size, OVInferRequestPtr infer_request, @@ -166,7 +140,7 @@ Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context, std::string output_name, std::unordered_map output_names, - std::shared_ptr node) { + std::shared_ptr node) { // Find position of '/' in the output_name int pos = output_name.find("/"); // Copy the substring from start to pos @@ -210,25 +184,25 @@ int GetFirstAvailableDevice(GlobalContext& global_context) { return i; } -void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor) { +void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor) { switch (node->get_element_type()) { - case ngraph::element::Type_t::f32: { + case ov::element::Type_t::f32: { FillOutputHelper(out_tensor, node); break; } - case ngraph::element::Type_t::boolean: { + case ov::element::Type_t::boolean: { FillOutputHelper(out_tensor, node); break; } - case ngraph::element::Type_t::i32: { + case ov::element::Type_t::i32: { FillOutputHelper(out_tensor, node); break; } - case ngraph::element::Type_t::i64: { + case ov::element::Type_t::i64: { FillOutputHelper(out_tensor, node); break; } - case ngraph::element::Type_t::f16: { + case ov::element::Type_t::f16: { FillOutputHelper(out_tensor, node); break; } @@ -237,14 +211,22 @@ void FillOutputsWithConstantData(std::shared_ptr node, Ort::Unowne } } +#if defined(_MSC_VER) +#pragma warning(disable : 4127) +#endif + template -void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node) { - auto const_node = std::dynamic_pointer_cast(node); +void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node) { + auto const_node = std::dynamic_pointer_cast(node); auto res = const_node->cast_vector(); T* tensor_data = out_tensor.GetTensorMutableData(); std::copy(res.begin(), res.end(), tensor_data); } +#if defined(_MSC_VER) +#pragma warning(default : 4127) +#endif + void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, std::string input_name, Ort::KernelContext& context, const SubGraphContext& subgraph_context) { diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index e0fdc6f55a4d..de78a150fe2d 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -32,19 +32,16 @@ bool IsCILogEnabled(); int GetFirstAvailableDevice(GlobalContext& global_context); -void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor); +void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor); template -void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node); +void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node); Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context, 
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
index 9bdbdbad592e..31b634e9a036 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -9,7 +9,6 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "../backend_utils.h"
-#include
 #include "basic_backend.h"
 #include "../backend_manager.h"
@@ -45,6 +44,7 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
   }
 #endif
   try {
+    std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str;
     if (global_context.is_wholly_supported_graph) {
 #if defined(IO_BUFFER_ENABLED)
       if ((global_context.device_type.find("GPU") != std::string::npos) &&
@@ -57,7 +57,7 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
         LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
       } else {
 #if defined(OPENVINO_2023_0)
-        if (subgraph_context.precision != InferenceEngine::Precision::FP16) {
+        if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") {
           const std::string model = model_proto.SerializeAsString();
           exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, device_config, subgraph_context_.subgraph_name);
           LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
@@ -72,8 +72,8 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
       LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
 #endif
 #else
-#if defined(OPENVINO_2023_0)
-      if (subgraph_context.precision != InferenceEngine::Precision::FP16 && global_context_.enable_dynamic_shapes == false) {
+#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1)
+      if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") {
        const std::string model = model_proto.SerializeAsString();
        exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, device_config, subgraph_context_.subgraph_name);
        LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
@@ -111,7 +111,7 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
   inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, nireq));
 }
 
-bool BasicBackend::ValidateSubgraph(std::map> & const_outputs_map) {
+bool BasicBackend::ValidateSubgraph(std::map> & const_outputs_map) {
   if (const_outputs_map.size() == subgraph_context_.output_names.size())
     subgraph_context_.is_constant = true;
   if (subgraph_context_.is_constant) {
@@ -122,17 +122,20 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
 }
 
 void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
-  // Set inference precision if device_type != AUTO
-  // if (global_context_.device_type.find("GPU_FP16")!= std::string::npos){
-  //   device_config.emplace(ov::hint::inference_precision(global_context_.precision_str));
-  // }
   device_config = {};
+  // Set inference precision based on device precision for OV backend
+  if (global_context_.precision_str.find("FP16") != std::string::npos && global_context_.device_type == "GPU") {
+    device_config.emplace(ov::hint::inference_precision("f16"));
+  }
+  if (global_context_.precision_str.find("FP32") != std::string::npos) {
+    device_config.emplace(ov::hint::inference_precision("f32"));
+  }
 #ifndef NDEBUG
   if (openvino_ep::backend_utils::IsDebugEnabled()) {
     device_config.emplace(ov::enable_profiling(true));
   }
 #endif
-#if defined(OPENVINO_2023_0)
+#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1)
   if (global_context_.device_type.find("VPUX") != std::string::npos) {
     std::pair device_property;
     device_property = std::make_pair("VPUX_COMPILER_TYPE", "MLIR");
@@ -160,7 +163,10 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
 void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) {
   if (global_context_.enable_opencl_throttling == true && global_context_.device_type.find("GPU") != std::string::npos) {
     LOGS_DEFAULT(INFO) << log_tag << "Enabled OpenCL queue throttling for GPU device";
-    device_config[GPU_CONFIG_KEY(PLUGIN_THROTTLE)] = "1";
+    std::pair device_property;
+    device_property = std::make_pair("PLUGIN_THROTTLE", "1");
+    device_config.emplace(ov::device::properties("GPU_CONFIG_KEY", device_property));
+    // device_config[GPU_CONFIG_KEY(PLUGIN_THROTTLE)] = "1";
   }
 }
@@ -190,7 +196,6 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
   }
   size_t batch_slice_idx = 0;
   if (subgraph_context_.has_dynamic_input_shape &&
-      global_context_.enable_dynamic_shapes == true &&
      (global_context_.device_type.find("CPU") != std::string::npos ||
       global_context_.device_type.find("GPU") != std::string::npos)) {
     auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name));
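Note: `PopulateConfigValue` now forwards an inference-precision hint through the `ov::AnyMap` that eventually reaches `compile_model`. A hedged sketch of that property plumbing; the device names, hint values, and unconditional profiling flag are illustrative, not the EP's exact behavior:

```cpp
#include <memory>
#include "openvino/openvino.hpp"

ov::CompiledModel CompileWithHints(ov::Core& core,
                                   const std::shared_ptr<ov::Model>& model,
                                   bool fp16_on_gpu) {
  ov::AnyMap device_config;
  if (fp16_on_gpu) {
    device_config.emplace(ov::hint::inference_precision("f16"));  // GPU_FP16 path above
  } else {
    device_config.emplace(ov::hint::inference_precision("f32"));  // FP32 path above
  }
  device_config.emplace(ov::enable_profiling(true));  // the EP only enables this in debug builds
  return core.compile_model(model, fp16_on_gpu ? "GPU" : "CPU", device_config);
}
```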
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h
index 8cdb758fe782..a5df3f49f76f 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.h
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -31,7 +31,7 @@ class BasicBackend : public IBackend {
  private:
   bool ImportBlob(std::string hw_target, bool vpu_status);
   void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&);
-  bool ValidateSubgraph(std::map>& const_outputs_map);
+  bool ValidateSubgraph(std::map>& const_outputs_map);
   void PopulateConfigValue(ov::AnyMap& device_config);
   void EnableCaching();
   void EnableGPUThrottling(ov::AnyMap& device_config);
@@ -48,7 +48,7 @@ class BasicBackend : public IBackend {
   mutable std::mutex compute_lock_;
   std::shared_ptr ie_cnn_network_;
   OVExeNetwork exe_network_;
-  std::map> const_outputs_map_;
+  std::map> const_outputs_map_;
   std::unique_ptr inferRequestsQueue_;
 #if defined IO_BUFFER_ENABLED
   OVRemoteContextPtr remote_context_;
diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h
index a6011590fafe..21fb43b10e20 100644
--- a/onnxruntime/core/providers/openvino/contexts.h
+++ b/onnxruntime/core/providers/openvino/contexts.h
@@ -40,7 +40,7 @@ struct SubGraphContext {
   std::vector input_indexes;
   std::unordered_map input_names;
   std::unordered_map output_names;
-  OVPrecision precision;
+  std::string precision;
 };
 
 }  // namespace openvino_ep
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
index 05eec6b1013b..d0804153a565 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -137,6 +137,10 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer,
   openvino_ep::GetCapability obj(graph_viewer,
                                  openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2023_0");
   result = obj.Execute();
+#elif defined(OPENVINO_2023_1)
+  openvino_ep::GetCapability obj(graph_viewer,
+                                 openvino_ep::BackendManager::GetGlobalContext().device_type, "V_2023_1");
+  result = obj.Execute();
 #endif
 
   return result;
diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
index d118b37f8ab6..65245c6d963c 100644
--- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
+++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
@@ -64,8 +64,31 @@ struct OpenVINO_Provider : Provider {
   std::shared_ptr CreateExecutionProviderFactory(const void* void_params) override {
     auto& params = *reinterpret_cast(void_params);
 
+    if (params.device_type != nullptr && !((std::string)params.device_type).empty()) {  // check device_type correctness only if provided, skip checks otherwise
+      std::string device_type = params.device_type;
+      std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32",
+                                            "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
+                                            "GPU.0_FP16", "GPU.1_FP16",
+                                            "NPU_FP16", "NPU_U8"};
+
+      if (!((ov_supported_device_types.find(device_type) != ov_supported_device_types.end()) ||
+            (device_type.find("HETERO:") == 0) || (device_type.find("MULTI:") == 0) || (device_type.find("AUTO:") == 0))) {
+        LOGS_DEFAULT(ERROR) << "[ERROR] [OpenVINO] You have selected a wrong configuration value for the key 'device_type'. Provided:"
+                            << device_type << "\n "
+                            << "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', "
+                               "'GPU.0_FP16', 'GPU.1_FP16', 'NPU_FP16', 'NPU_U8' or from"
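Note: with the prefix checks above, composite device strings still pass validation. A few illustrative values that the new check accepts (the listed devices must of course exist on the target machine):

```cpp
OrtOpenVINOProviderOptions ov_options;
ov_options.device_type = "GPU_FP16";           // exact match against ov_supported_device_types
// ov_options.device_type = "HETERO:GPU,CPU";  // prefix match, passes the same check
// ov_options.device_type = "AUTO:GPU,CPU";    // prefix match
ov_options.num_of_threads = 8;                 // values <= 0 are clamped to 1 with a warning
```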
\n"; + } + } + int num_of_threads = params.num_of_threads; + if (num_of_threads <= 0) { + num_of_threads = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_threads' should be in the positive range.\n " + << "Executing with num_threads=1"; + } + return std::make_shared(params.device_type, params.enable_vpu_fast_compile, - params.device_id, params.num_of_threads, + params.device_id, num_of_threads, params.cache_dir, params.context, params.enable_opencl_throttling, params.enable_dynamic_shapes); diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 9175f51b120d..ce8f2cc49a86 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -42,7 +42,7 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std } } -#if defined(OPENVINO_2023_0) +#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) OVExeNetwork OVCore::LoadNetwork(const std::string& model, std::string& hw_target, ov::AnyMap& device_config, std::string name) { ov::CompiledModel obj; try { @@ -75,8 +75,14 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& model, OVRemoteCont #endif std::vector OVCore::GetAvailableDevices() { - auto obj = oe.get_available_devices(); - return obj; + auto available_devices = oe.get_available_devices(); + for (int i = 0; i < int(available_devices.size()); i++) { + if (available_devices[i].find("GPU") != std::string::npos) { + std::string luid_str = oe.get_property(available_devices[i], ov::device::luid.name()).as(); + available_devices[i] = available_devices[i] + "_" + luid_str; + } + } + return available_devices; } OVInferRequest OVExeNetwork::CreateInferRequest() { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 84268ab6dcb1..1517cf23b8b1 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -5,11 +5,12 @@ #include -#include -#if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) +#if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) #define OV_API_20 #include "openvino/openvino.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" +#else +#include #endif #ifdef IO_BUFFER_ENABLED @@ -26,10 +27,8 @@ class OVCore; class OVInferRequest; class OVExeNetwork; -typedef InferenceEngine::Precision OVPrecision; typedef ov::Tensor OVTensor; typedef ov::ProfilingInfo OVProfilingInfo; -typedef ov::AnyMap OVConfig; typedef ov::Model OVNetwork; typedef std::shared_ptr OVInferRequestPtr; typedef std::shared_ptr OVTensorPtr; @@ -45,7 +44,7 @@ class OVCore { public: std::shared_ptr ReadModel(const std::string& model_stream) const; OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name); -#if defined(OPENVINO_2023_0) +#if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) OVExeNetwork LoadNetwork(const std::string& model_stream, std::string& hw_target, ov::AnyMap& device_config, std::string name); #endif void SetCache(std::string cache_dir_path); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 88db95c1fbe9..1fe62b3466f8 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ 
diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
index 88db95c1fbe9..1fe62b3466f8 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
@@ -17,8 +17,7 @@
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
 #endif
-#include
-#include
+
 #if defined(_MSC_VER)
 #pragma warning(default : 4244 4245)
 #elif __GNUC__
@@ -123,6 +122,7 @@ std::vector supported_op_mode = {
     {"Dropout", V_2023_0, {"VPUX"}},
     {"Elu", V_2020_4, {"CPU", "GPU"}},
     {"Elu", V_2023_0, {"VPUX"}},
+    {"Einsum", V_2023_0, {"CPU", "GPU"}},
     {"Equal", V_2020_4, {"CPU", "GPU"}},
     {"Equal", V_2023_0, {"VPUX"}},  // Added for whisper decoder model.
     {"Erf", V_2020_4, {"CPU", "GPU"}},
@@ -151,6 +151,7 @@ std::vector supported_op_mode = {
     {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}},
     {"GreaterOrEqual", V_2023_0, {"VPUX"}},
     {"GridSample", V_2022_3, {"CPU"}},
+    {"GridSample", V_2023_0, {"GPU"}},
     {"Identity", V_2020_4, {"CPU", "GPU"}},
     {"Identity", V_2023_0, {"VPUX"}},  // NoOP
     {"If", V_2022_3, {"CPU", "GPU"}},
@@ -192,6 +193,7 @@ std::vector supported_op_mode = {
     {"Neg", V_2023_0, {"VPUX"}},
     {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}},
     {"NonZero", V_2021_1, {"CPU"}},
+    {"NonZero", V_2023_0, {"GPU"}},
     {"Not", V_2021_1, {"CPU", "GPU"}},
     {"Not", V_2020_4, {"CPU", "GPU"}},
     {"OneHot", V_2020_4, {"CPU", "GPU"}},
@@ -205,6 +207,7 @@ std::vector supported_op_mode = {
     {"QLinearMatMul", V_2022_3, {"CPU"}},
     {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}},
     {"QuantizeLinear", V_2023_0, {"VPUX"}},
+    {"RandomNormal", V_2023_0, {"CPU", "GPU"}},
     {"Range", V_2022_1, {"CPU", "GPU"}},
     {"Range", V_2023_0, {"VPUX"}},
     {"Reciprocal", V_2020_4, {"CPU", "GPU"}},
@@ -335,6 +338,7 @@ void DataOps::populate_op_mode_supported() {
   no_dimension_supported_.push_back({"Div", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"DequantizeLinear", V_2021_4, {"All"}});
   no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}});
+  no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}});
   no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Greater", V_2023_0, {"VPUX"}});
@@ -350,6 +354,7 @@ void DataOps::populate_op_mode_supported() {
   no_dimension_supported_.push_back({"ReduceProd", V_2022_1, {"CPU", "GPU"}});
   no_dimension_supported_.push_back({"Reshape", V_2022_1, {"All"}});
   no_dimension_supported_.push_back({"Shape", V_2022_1, {"GPU"}});
+  no_dimension_supported_.push_back({"Shape", V_2023_0, {"CPU"}});
   no_dimension_supported_.push_back({"Squeeze", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Sub", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Unsqueeze", V_2020_4, {"All"}});
@@ -1016,8 +1021,10 @@ bool DataOps::node_is_supported(const std::mapdim()) {
       if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
-        if ((device_id_.find("GPU") != std::string::npos) && ((optype == "Expand") ||
-            (optype == "Slice") || (optype == "Concat") || (optype == "Shape"))) {
+        if (((device_id_.find("CPU") != std::string::npos) || (device_id_.find("GPU") != std::string::npos)) &&
+            ((optype == "Expand") || (optype == "Equal") ||
+             (optype == "Slice") || (optype == "Concat") ||
+             (optype == "Shape"))) {
           return;
         }
         has_unsupported_dimension = true;
"openvino/core/deprecated.hpp" +#define IN_OV_COMPONENT +#define NGRAPH_LEGACY_HEADER_INCLUDED #include + +#undef NGRAPH_LEGACY_HEADER_INCLUDED +#undef IN_OV_COMPONENT + #if defined(_MSC_VER) #pragma warning(default : 4244 4245) #elif __GNUC__ @@ -87,6 +94,7 @@ int GetOnnxOpSet(const GraphViewer& graph_viewer) { std::map> GetNgSupportedOps(const int onnx_opset) { std::map> ng_supported_ops; + OPENVINO_SUPPRESS_DEPRECATED_START ng_supported_ops.emplace(kOnnxDomain, ngraph::onnx_import::get_supported_operators(onnx_opset, kOnnxDomain)); const std::set ng_disabled_ops = {"LSTM"}; // Place-holder for ops not supported. @@ -94,7 +102,7 @@ std::map> GetNgSupportedOps(const int onnx_op for (const auto& disabled_op : ng_disabled_ops) { ng_supported_ops.at(kOnnxDomain).erase(disabled_op); } - + OPENVINO_SUPPRESS_DEPRECATED_END return ng_supported_ops; }