diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake
index d738e29101cf..5d1a481d40ab 100644
--- a/cmake/onnxruntime_providers_openvino.cmake
+++ b/cmake/onnxruntime_providers_openvino.cmake
@@ -17,8 +17,8 @@
# Header paths
find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX)
- if(OpenVINO_VERSION VERSION_LESS 2023.0)
- message(FATAL_ERROR "OpenVINO 2023.0 and newer are supported. Please, latest OpenVINO release")
+ if(OpenVINO_VERSION VERSION_LESS 2024.0)
+ message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. Please, use latest OpenVINO release")
endif()
if (WIN32)
diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst
index 6ef16e137813..86914699bbf6 100644
--- a/docs/python/ReadMeOV.rst
+++ b/docs/python/ReadMeOV.rst
@@ -7,6 +7,7 @@ OpenVINO™ Execution Provider for ONNX Runtime accelerates inference across man
- Intel® CPUs
- Intel® integrated GPUs
- Intel® discrete GPUs
+ - Intel® integrated NPUs (Windows only)
Installation
------------
@@ -15,26 +16,27 @@ Requirements
^^^^^^^^^^^^
- Ubuntu 18.04, 20.04, RHEL(CPU only) or Windows 10 - 64 bit
-- Python 3.8 or 3.9 or 3.10 for Linux and only Python3.10 for Windows
+- Python 3.9 or 3.10 or 3.11 for Linux and Python 3.10, 3.11 for Windows
This package supports:
- Intel® CPUs
- Intel® integrated GPUs
- Intel® discrete GPUs
+ - Intel® integrated NPUs (Windows only)
``pip3 install onnxruntime-openvino``
Please install OpenVINO™ PyPi Package separately for Windows.
For installation instructions on Windows please refer to `OpenVINO™ Execution Provider for ONNX Runtime for Windows `_.
-**OpenVINO™ Execution Provider for ONNX Runtime** Linux Wheels comes with pre-built libraries of OpenVINO™ version 2023.0.0 eliminating the need to install OpenVINO™ separately. The OpenVINO™ libraries are prebuilt with CXX11_ABI flag set to 0.
+**OpenVINO™ Execution Provider for ONNX Runtime** Linux Wheels comes with pre-built libraries of OpenVINO™ version 2024.1.0 eliminating the need to install OpenVINO™ separately.
For more details on build and installation please refer to `Build `_.
Usage
^^^^^
-By default, Intel® CPU is used to run inference. However, you can change the default option to either Intel® integrated or discrete GPU.
+By default, Intel® CPU is used to run inference. However, you can change the default option to either Intel® integrated GPU, discrete GPU, integrated NPU (Windows only).
Invoke `the provider config device type argument `_ to change the hardware on which inferencing is done.
For more API calls and environment variables, see `Usage `_.
diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc
index f0e76312d6e0..7b518947138a 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc
@@ -3,8 +3,13 @@
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
+#include
#include
+#include
+#include
+#include
+#include "core/common/inlined_containers_fwd.h"
#include "core/graph/extended_graph_edge.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
@@ -17,39 +22,147 @@ namespace onnxruntime {
namespace {
bool CanNodePropagate(const Node& node) {
return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13}) ||
- graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13});
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "Reshape", {5, 13, 14, 19, 21}) ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13, 21}) ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "Squeeze", {1, 11, 13, 21}) ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "Unsqueeze", {1, 11, 13, 21}) ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13});
}
-// convert this: src_node -> dst_node
-// to this: src_node -> Q -> DQ -> dst_node
-// assumptions:
-// 1. insertion_edge is valid - node indexes refer to valid nodes, arg name refers to a valid NodeArg, and it
-// corresponds to an actual graph relationship
-// 2. scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers
-Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge,
- NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr,
- const std::string& qdq_domain, const logging::Logger& logger) {
- auto* src_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source);
- auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
-
- ORT_ENFORCE(src_node || dst_node, "At least one graph node must be specified in the propagation edge.");
-
- const auto& base_name = insertion_edge.arg_name;
+// Makes matching attributes for new QuantizeLinear nodes from an existing DequantizeLinear node.
+NodeAttributes MakeQAttrsFromDQ(const Node& dq_node) {
+ assert(dq_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchDQNode().
+ // In opset <= 21, all DQ attributes (i.e., axis and block_size) are also Q attributes.
+ // So, set a copy of the DQ attributes.
+ return dq_node.GetAttributes();
+}
+
+// Makes matching attributes for new DequantizeLinear nodes from an existing QuantizeLinear node.
+NodeAttributes MakeDQAttrsFromQ(const Node& q_node) {
+ assert(q_node.SinceVersion() <= 21); // Checked by previous call to QDQ::MatchQNode().
+ const NodeAttributes& q_attrs = q_node.GetAttributes();
+ if (q_attrs.empty()) {
+ return {};
+ }
+
+ // In opset <= 21, only the "axis" and "block_size" attributes for Q are also DQ attributes.
+ NodeAttributes dq_attrs;
+
+ auto axis_attr_it = q_attrs.find("axis");
+ if (axis_attr_it != q_attrs.end()) {
+ dq_attrs.insert({axis_attr_it->first, axis_attr_it->second});
+ }
+
+ auto block_size_attr_it = q_attrs.find("block_size");
+ if (block_size_attr_it != q_attrs.end()) {
+ dq_attrs.insert({block_size_attr_it->first, block_size_attr_it->second});
+ }
+
+ return dq_attrs;
+}
+
+// Validates edges into which to insert Q -> DQ ops.
+// - Must have at least one edge.
+// - All edges must correspond to the same graph NodeArg (i.e., same source but potentially different destination).
+// - All edges must be attached to either a source node or a destination node.
+Status ValidateQDQInsertionEdges(Graph& graph, gsl::span insertion_edges) {
+ const size_t num_edges = insertion_edges.size();
+ ORT_RETURN_IF(num_edges == 0, "Expected at least one edge into which to insert QDQ pair.");
+
+ const ExtendedGraphEdge& first_edge = insertion_edges[0];
+ const Node* src_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source);
+ const Node* first_dst_node = first_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
+ const std::string& node_arg_name = first_edge.arg_name;
+ ORT_RETURN_IF_NOT(graph.GetNodeArg(node_arg_name) != nullptr,
+ "QDQ insertion edge does not have a valid graph NodeArg for ", node_arg_name);
+ ORT_RETURN_IF_NOT(src_node != nullptr || first_dst_node != nullptr,
+ "QDQ insertion edge [0] for NodeArg ", node_arg_name,
+ " must have a source or a destination node");
+
+ for (size_t i = 1; i < num_edges; i++) {
+ const ExtendedGraphEdge& insertion_edge = insertion_edges[i];
+ ORT_RETURN_IF_NOT(insertion_edge.arg_name == node_arg_name,
+ "QDQ insertion edge [", i, "] has NodeArg ", insertion_edge.arg_name,
+ " but expected NodeArg ", node_arg_name);
+
+ const Node* edge_dst_node = insertion_edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
+ ORT_RETURN_IF_NOT(src_node != nullptr || edge_dst_node != nullptr,
+ "QDQ insertion edge [", i, "] for NodeArg ", node_arg_name,
+ " must have a source or a destination node");
+ }
+
+ return Status::OK();
+}
+
+// Logs information about the edges into which Q/DQ nodes will be inserted in InsertQDQPairs().
+// Assumes the edges have already been validated.
+void LogQDQInsertion(const logging::Logger& logger, logging::Severity severity, const CodeLocation& code_location,
+ const Graph& graph, gsl::span edges) {
+ auto logging_data_type = logging::DataType::SYSTEM;
+ if (!logger.OutputIsEnabled(severity, logging_data_type)) {
+ return;
+ }
+
+ const Node* src_node = edges[0].GetNodeAtEnd(graph, ExtendedGraphEdge::End::Source);
+ const auto& node_arg_name = edges[0].arg_name;
+ std::string src_label = src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")")
+ : "input";
+ std::ostringstream dst_labels;
+ const size_t num_edges = edges.size();
+
+ for (size_t i = 0; i < num_edges; ++i) {
+ const ExtendedGraphEdge& edge = edges[i];
+ const Node* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
+ dst_labels << (dst_node ? MakeString("dst node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")")
+ : "output")
+ << (i == num_edges - 1 ? "" : ",");
+ }
+
+ logging::Capture(logger, severity, logging::Category::onnxruntime, logging_data_type, code_location).Stream()
+ << "Inserted Q/DQ pair between "
+ << (src_node ? MakeString("src node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")")
+ : "input")
+ << " and " << dst_labels.str()
+ << " at NodeArg \"" << node_arg_name << "\".";
+}
+
+// convert this: src_node (or graph input) --+--> dst_node_0 (or graph output)
+// |
+// +--> dst_node_1
+// | ...
+// +--> dst_node_n
+//
+// to this: src_node (or graph input) -> Q --+--> DQ -> dst_node_0 (or graph output)
+// |
+// +--> DQ -> dst_node_1
+// | ...
+// +--> DQ -> dst_node_n
+// Checks that all insertion edges share the same NodeArg. That is, the edges originate from the same source node
+// output. If there is no src_node, then all edges should come from the same graph input.
+// This function returns an error status if edges are invalid.
+//
+// Assumes that scale_initializer_nodearg and zp_initializer_nodearg_ptr (if not null) are constant initializers.
+Status InsertQDQPairs(Graph& graph, gsl::span insertion_edges,
+ NodeArg& scale_initializer_nodearg, NodeArg* zp_initializer_nodearg_ptr,
+ const std::string& qdq_domain, const NodeAttributes& q_attrs, const NodeAttributes& dq_attrs,
+ const logging::Logger& logger) {
+ ORT_RETURN_IF_ERROR(ValidateQDQInsertionEdges(graph, insertion_edges));
+
+ const ExtendedGraphEdge& first_edge = insertion_edges[0]; // ValidateQDQInsertionEdges() guarantees at least one edge
+
+ Node* src_node = first_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Source); // nullptr for graph input
+ const auto& base_name = first_edge.arg_name;
auto& base_node_arg = *graph.GetNodeArg(base_name);
- LOGS(logger, VERBOSE) << "Inserting Q/DQ pair between "
- << (src_node ? MakeString("node (\"", src_node->Name(), "\", index: ", src_node->Index(), ")")
- : "input")
- << " and "
- << (dst_node ? MakeString("node (\"", dst_node->Name(), "\", index: ", dst_node->Index(), ")")
- : "output")
- << " at NodeArg \"" << base_name << "\".";
+ LogQDQInsertion(logger, logging::Severity::kVERBOSE, ORT_WHERE, graph, insertion_edges);
- // set up new NodeArgs
- auto& pre_q_nodearg = insertion_edge.HasGraphInputOrInitializer()
+ auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) {
+ return zero_point ? InlinedVector{&data, &scale, zero_point}
+ : InlinedVector{&data, &scale};
+ };
+
+ // Create Q node that will be inserted after src_node
+ auto& pre_q_nodearg = first_edge.HasGraphInputOrInitializer()
? base_node_arg
: graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_pre_q"),
nullptr);
@@ -57,17 +170,6 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge,
auto& q_to_dq_nodearg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_q_to_dq"),
nullptr);
- auto& post_dq_nodearg = insertion_edge.HasGraphOutput()
- ? base_node_arg
- : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(base_name + "_post_dq"),
- nullptr);
-
- // set up new Nodes
- auto make_q_or_dq_inputs = [](NodeArg& data, NodeArg& scale, NodeArg* zero_point) {
- return zero_point ? std::vector{&data, &scale, zero_point}
- : std::vector{&data, &scale};
- };
-
auto& q_node = graph.AddNode(graph.GenerateNodeName(base_name + "_q"),
QDQ::QOpName,
"Inserted by QDQPropagationTransformer",
@@ -76,40 +178,61 @@ Status InsertQDQPair(Graph& graph, const ExtendedGraphEdge& insertion_edge,
zp_initializer_nodearg_ptr),
// outputs
{&q_to_dq_nodearg},
- nullptr, // attributes
+ &q_attrs, // attributes
qdq_domain);
ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(q_node), "Failed to set op schema for added Q node.");
- auto& dq_node = graph.AddNode(graph.GenerateNodeName(base_name + "_dq"),
- QDQ::DQOpName,
- "Inserted by QDQPropagationTransformer",
- // inputs
- make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg,
- zp_initializer_nodearg_ptr),
- // outputs
- {&post_dq_nodearg},
- nullptr, // attributes
- qdq_domain);
-
- ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node.");
-
- // set up edges
- if (src_node && dst_node) {
- graph.RemoveEdge(src_node->Index(), dst_node->Index(),
- insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx);
- }
-
if (src_node) {
- src_node->MutableOutputDefs()[insertion_edge.src->arg_idx] = &pre_q_nodearg;
- graph.AddEdge(src_node->Index(), q_node.Index(), insertion_edge.src->arg_idx, 0);
- }
+ // Remove original edges between src and dst nodes.
+ for (const auto& insertion_edge : insertion_edges) {
+ auto* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
+
+ if (dst_node) {
+ graph.RemoveEdge(src_node->Index(), dst_node->Index(),
+ insertion_edge.src->arg_idx, insertion_edge.dst->arg_idx);
+ }
+ }
- graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0);
+ // Add edge from src to Q node.
+ src_node->MutableOutputDefs()[first_edge.src->arg_idx] = &pre_q_nodearg;
+ graph.AddEdge(src_node->Index(), q_node.Index(), first_edge.src->arg_idx, 0);
+ }
- if (dst_node) {
- dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg;
- graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx);
+ // Create a DQ node for each dst node and connect remaining edges.
+ for (size_t edge_idx = 0; edge_idx < insertion_edges.size(); ++edge_idx) {
+ const auto& insertion_edge = insertion_edges[edge_idx];
+ const std::string edge_suffix = edge_idx == 0 ? "" : std::to_string(edge_idx);
+ auto& post_dq_nodearg = insertion_edge.HasGraphOutput()
+ ? base_node_arg
+ : graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(MakeString(base_name,
+ "_post_dq",
+ edge_suffix)),
+ nullptr);
+
+ auto& dq_node = graph.AddNode(graph.GenerateNodeName(MakeString(base_name, "_dq", edge_suffix)),
+ QDQ::DQOpName,
+ "Inserted by QDQPropagationTransformer",
+ // inputs
+ make_q_or_dq_inputs(q_to_dq_nodearg, scale_initializer_nodearg,
+ zp_initializer_nodearg_ptr),
+ // outputs
+ {&post_dq_nodearg},
+ &dq_attrs, // attributes
+ qdq_domain);
+
+ ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(dq_node), "Failed to set op schema for added DQ node.");
+
+ Node* dst_node = insertion_edge.GetMutableNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
+
+ // Add edge from Q to DQ
+ graph.AddEdge(q_node.Index(), dq_node.Index(), 0, 0);
+
+ // Add edge from DQ to dst_node
+ if (dst_node) {
+ dst_node->MutableInputDefs()[insertion_edge.dst->arg_idx] = &post_dq_nodearg;
+ graph.AddEdge(dq_node.Index(), dst_node->Index(), 0, insertion_edge.dst->arg_idx);
+ }
}
return Status::OK();
@@ -156,37 +279,39 @@ std::optional GetPreviousPropagationEdge(const Graph& graph,
return GetPreviousEdge(graph, *src_node);
}
-std::optional GetNextEdge(const Graph& graph, const Node& node) {
- // for now we can just consider the first output (index 0)
+InlinedVector GetNextEdges(const Graph& graph, const Node& node) {
+ constexpr int node_output_index = 0; // for now we can just consider the first output (index 0)
+ InlinedVector next_edges;
+ const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, static_cast(node_output_index));
- const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, 0);
- if (output_edges.empty()) {
- // maybe edge to output
- return ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, 0);
+ // edges to next nodes
+ for (const auto& output_edge : output_edges) {
+ next_edges.push_back(ExtendedGraphEdge::CreateFromValidGraphEdge(output_edge));
}
- if (!graph.IsOutput(node.OutputDefs()[0]) && output_edges.size() == 1) {
- // single edge to next node
- return ExtendedGraphEdge::CreateFromValidGraphEdge(output_edges.front());
+ // maybe edge to graph output
+ auto edge_to_output = ExtendedGraphEdge::TryCreateFromNodeToOutput(graph, node, node_output_index);
+ if (edge_to_output.has_value()) {
+ next_edges.push_back(edge_to_output.value());
}
- return std::nullopt;
+ return next_edges;
}
-std::optional GetNextPropagationEdge(const Graph& graph,
- const ExtendedGraphEdge& edge) {
+InlinedVector GetNextPropagationEdges(const Graph& graph,
+ const ExtendedGraphEdge& edge) {
if (edge.HasGraphOutput()) {
- return std::nullopt;
+ return {};
}
const auto* dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
ORT_ENFORCE(dst_node != nullptr);
if (!CanNodePropagate(*dst_node)) {
- return std::nullopt;
+ return {};
}
- return GetNextEdge(graph, *dst_node);
+ return GetNextEdges(graph, *dst_node);
}
class GraphConstantInitializerGetter {
@@ -228,21 +353,54 @@ Status PropagateDQForward(Graph& graph, gsl::span node_indices,
? dq_node.MutableInputDefs()[QDQ::InputIndex::ZERO_POINT_ID]
: nullptr;
- const auto edge_after_dq = GetNextEdge(graph, dq_node);
- if (!edge_after_dq) {
+ const InlinedVector edges_after_dq = GetNextEdges(graph, dq_node);
+ if (edges_after_dq.size() != 1) {
continue;
}
- for (auto curr_edge = GetNextPropagationEdge(graph, *edge_after_dq);
- curr_edge.has_value();
- curr_edge = GetNextPropagationEdge(graph, *curr_edge)) {
- if (const auto* dst_node = curr_edge->GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
- dst_node && QDQ::MatchQNode(*dst_node)) {
- break;
+ // Utility function to check if any edge out of a node (e.g., Transpose) ends in a Q node.
+ auto any_edge_ends_in_q = [](Graph& graph, const InlinedVector& edges) -> bool {
+ for (const auto& edge : edges) {
+ const auto* edge_dst_node = edge.GetNodeAtEnd(graph, ExtendedGraphEdge::End::Destination);
+ if (edge_dst_node && QDQ::MatchQNode(*edge_dst_node)) {
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Propagate DQ forward in a BFS traversal of NodeArg edges. A NodeArg "edge group" consists of one or more edges
+ // that all begin at the same source node's output slot and end at a graph output or a destination node.
+ // Ex: The subgraph below shows a NodeArg edge group (containing 3 edges) that begins at a
+ // Transpose, ends at two destination nodes, and produces a graph output.
+ // DQ -> Transpose --+--> Sigmoid -> ...
+ // |
+ // +--> Slice -> ...
+ // |
+ // +--> graph_output
+ std::queue> node_arg_edges;
+ node_arg_edges.push(GetNextPropagationEdges(graph, edges_after_dq[0]));
+
+ while (!node_arg_edges.empty()) {
+ const InlinedVector curr_edge_group = std::move(node_arg_edges.front());
+ node_arg_edges.pop();
+
+ // Skip if edge group is empty. Also, to keep things simple, we do not yet handle edge groups in which
+ // one of the destination nodes is already a QuantizeLinear node. Ex:
+ // DQ -> Transpose --+--> QuantizeLinear -> ...
+ // |
+ // +--> Slice -> ...
+ if (curr_edge_group.empty() || any_edge_ends_in_q(graph, curr_edge_group)) {
+ continue;
}
- ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, dq_scale, dq_zero_point, dq_node.Domain(), logger));
+ ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, curr_edge_group, dq_scale, dq_zero_point, dq_node.Domain(),
+ MakeQAttrsFromDQ(dq_node), dq_node.GetAttributes(), logger));
modified = true;
+
+ for (const auto& edge : curr_edge_group) {
+ node_arg_edges.push(GetNextPropagationEdges(graph, edge));
+ }
}
}
@@ -290,7 +448,8 @@ Status PropagateQBackward(Graph& graph, gsl::span node_indices,
break;
}
- ORT_RETURN_IF_ERROR(InsertQDQPair(graph, *curr_edge, q_scale, q_zero_point, q_node.Domain(), logger));
+ ORT_RETURN_IF_ERROR(InsertQDQPairs(graph, InlinedVector{*curr_edge}, q_scale, q_zero_point,
+ q_node.Domain(), q_node.GetAttributes(), MakeDQAttrsFromQ(q_node), logger));
modified = true;
}
}
diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index 1c027e39fa5f..8f3658df0d09 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -28,9 +28,8 @@ BackendManager::BackendManager(const GlobalContext& global_context,
const onnxruntime::Node& fused_node,
const onnxruntime::GraphViewer& subgraph,
const logging::Logger& logger,
- EPCtxHandler& ctx_handle) {
+ EPCtxHandler& ep_ctx_handle_) {
global_context_ = global_context;
- ep_ctx_handle_ = ctx_handle;
openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." +
std::to_string(global_context_.OpenVINO_Version.at(1));
@@ -147,13 +146,20 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie
std::string model_blob_str;
auto compiled_model = concrete_backend_->GetOVCompiledModel();
- auto graph_name = global_context_.onnx_model_path_name;
- // Remove extension so we can append suffix to form the complete name of output graph
- graph_name = [&]() {
- size_t dot = graph_name.find_last_of(".");
- if (dot == std::string::npos) return graph_name;
- return graph_name.substr(0, dot);
- }();
+ std::string graph_name = "";
+ // Epctx file path from SO is mapped to cache_dir variable for OVEP for readability
+ if (global_context_.cache_dir != "") {
+ graph_name = global_context_.cache_dir;
+ } else {
+ graph_name = global_context_.onnx_model_path_name;
+ // Remove extension so we can append suffix to form the complete name of output graph
+ graph_name = [&]() {
+ size_t dot = graph_name.find_last_of(".");
+ if (dot == std::string::npos) return graph_name;
+ return graph_name.substr(0, dot);
+ }();
+ graph_name = graph_name + "-ov_" + GetGlobalContext().device_type + "_blob.onnx";
+ }
// If embed_mode, then pass on the serialized blob
// If not embed_mode, dump the blob here and only pass on the path to the blob
if (global_context_.ep_context_embed_mode) {
@@ -162,9 +168,19 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie
model_blob_str = model_blob_stream.str();
ORT_ENFORCE(model_blob_str.size() != 0);
} else {
- std::ofstream f(graph_name + ".blob", std::ios::out | std::ios::trunc | std::ios::binary);
- compiled_model.export_model(f);
- model_blob_str = graph_name + ".blob";
+ // Remove extension so we can append suffix to form the complete name of output graph
+ auto blob_name = [&]() {
+ size_t dot = graph_name.find_last_of(".");
+ if (dot == std::string::npos) return graph_name;
+ return graph_name.substr(0, dot);
+ }();
+ std::ofstream blob_file(blob_name + ".blob",
+ std::ios::out | std::ios::trunc | std::ios::binary);
+ if (!blob_file) {
+ ORT_THROW("Unable to open file for epctx model dump.");
+ }
+ compiled_model.export_model(blob_file);
+ model_blob_str = blob_name + ".blob";
}
ORT_RETURN_IF_ERROR(ep_ctx_handle_.ExportEPCtxModel(graph_body_viewer,
@@ -172,8 +188,7 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie
logger,
global_context_.ep_context_embed_mode,
model_blob_str,
- openvino_sdk_version_,
- GetGlobalContext().device_type));
+ openvino_sdk_version_));
return Status::OK();
}
@@ -248,7 +263,7 @@ static void DumpOpenVINOEPModel(std::string onnx_model_path_name,
ONNX_NAMESPACE::ModelProto* model_proto,
const onnxruntime::Node& fused_node) {
if (openvino_ep::backend_utils::IsDebugEnabled()) {
- auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name;
+ auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : std::move(onnx_model_path_name);
#ifdef _WIN32
size_t slash = model_name.find_last_of("\\");
#else
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
index f8046bcb3a06..d79aa35be641 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -37,7 +37,7 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
PopulateConfigValue(device_config);
// Enable caching
- EnableCaching();
+ EnableCaching(device_config);
// Setting OpenCL queue throttling for GPU
EnableGPUThrottling(device_config);
@@ -82,26 +82,28 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
}
#else // !IO_BUFFER_ENABLED
+ std::string prec_str = (global_context_.precision_str != "ACCURACY") ? global_context_.precision_str : global_context_.model_precision;
if (is_ep_ctx_graph_) {
// If the blob is held in an EPContext node, then skip FE+Compile
// and directly move on to creating a backend with the executable blob
exe_network_ = global_context_.ie_core.ImportModel(ep_ctx_handle.GetModelBlobStream(),
hw_target,
device_config,
+ global_context_.ep_context_embed_mode,
subgraph_context_.subgraph_name);
ie_cnn_network_ = exe_network_.Get().get_runtime_model();
- } else if (!subgraph_context_.has_dynamic_input_shape) {
+ } else if ((!subgraph_context_.has_dynamic_input_shape) &&
+ ((hw_target.find("AUTO") == std::string::npos) ||
+ (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) > 2))) {
+ // Optimized OV compile_model API is supported with AUTO from version 2024.3 and above
// Inputs with static dimenstions
- std::string prec_str = (global_context_.precision_str != "ACCURACY") ? global_context_.precision_str : global_context_.model_precision;
const std::string model = model_proto.SerializeAsString();
exe_network_ = global_context_.ie_core.CompileModel(model,
hw_target,
- prec_str,
- global_context_.cache_dir,
device_config,
subgraph_context_.subgraph_name);
ie_cnn_network_ = exe_network_.Get().get_runtime_model();
- } else { // Inputs with dynamic dimensions
+ } else { // For all other types use ov::Model Type
ie_cnn_network_ = CreateOVModel(model_proto, global_context_, const_outputs_map_);
exe_network_ = global_context_.ie_core.CompileModel(
ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
@@ -173,13 +175,19 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
}
}
-void BasicBackend::EnableCaching() {
+void BasicBackend::EnableCaching(ov::AnyMap& device_config) {
// cache_dir argument has no effect when working with an embed-mode EPContext Graph
if (is_ep_ctx_graph_) return;
- if (!global_context_.cache_dir.empty()) {
+ if (!global_context_.cache_dir.empty() && !global_context_.export_ep_ctx_blob) {
LOGS_DEFAULT(INFO) << log_tag << "Enables Caching";
- global_context_.ie_core.SetCache(global_context_.cache_dir, global_context_.device_type);
+ if (global_context_.device_type.find("AUTO:GPU") != std::string::npos) {
+ std::pair device_property;
+ device_property = std::make_pair("CACHE_DIR", global_context_.cache_dir);
+ device_config.emplace(ov::device::properties("GPU", device_property));
+ } else {
+ global_context_.ie_core.SetCache(global_context_.cache_dir);
+ }
}
}
@@ -274,7 +282,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
}
try {
- infer_request->SetTensor(input_name, tensor_ptr);
+ infer_request->SetTensor(std::move(input_name), tensor_ptr);
} catch (const char* msg) {
ORT_THROW(msg);
}
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h
index 5565223f067b..bcd3161590ba 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.h
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -37,7 +37,7 @@ class BasicBackend : public IBackend {
void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&);
bool ValidateSubgraph(std::map>& const_outputs_map);
void PopulateConfigValue(ov::AnyMap& device_config);
- void EnableCaching();
+ void EnableCaching(ov::AnyMap& device_config);
void EnableGPUThrottling(ov::AnyMap& device_config);
void EnableStreams();
void SetNumThreads(ov::AnyMap& device_config);
diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc
index cd1ae6150e1d..e2df9c83f15a 100644
--- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc
@@ -19,8 +19,7 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer,
const logging::Logger& logger,
const bool& ep_context_embed_mode,
const std::string& model_blob_str,
- const std::string& openvino_sdk_version,
- const std::string& device_type) const {
+ const std::string& openvino_sdk_version) const {
auto model_build = graph_viewer.CreateModel(logger);
auto& graph_build = model_build->MainGraph();
@@ -77,9 +76,12 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer,
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
// Finally, dump the model
- std::ofstream dump(graph_name + "-ov_" + device_type + "_blob.onnx",
- std::ios::out | std::ios::trunc | std::ios::binary);
- model_proto->SerializeToOstream(dump);
+ std::ofstream epctx_onnx_model(graph_name,
+ std::ios::out | std::ios::trunc | std::ios::binary);
+ if (!epctx_onnx_model) {
+ ORT_THROW("Unable to create epctx onnx model file ");
+ }
+ model_proto->SerializeToOstream(epctx_onnx_model);
LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Export blob as EPContext Node";
@@ -90,9 +92,7 @@ Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer) {
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();
ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) > 0);
-
model_stream_ = std::make_shared(attrs.at(EP_CACHE_CONTEXT).s());
-
LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node";
is_valid_ep_ctx_graph_ = true;
diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h
index b2b9b5bc53d4..610e9fd49c90 100644
--- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h
@@ -29,8 +29,7 @@ class EPCtxHandler {
const logging::Logger& logger,
const bool& ep_context_embed_mode,
const std::string& model_blob_str,
- const std::string& openvino_sdk_version,
- const std::string& device_type) const;
+ const std::string& openvino_sdk_version) const;
Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer);
bool CheckForOVEPCtxNode(const GraphViewer& graph_viewer, std::string openvino_sdk_version) const;
bool IsValidOVEPCtxGraph() const { return is_valid_ep_ctx_graph_; }
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
index 655e1b180388..5627cb2c122f 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -34,6 +34,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv
global_context_->export_ep_ctx_blob = info.export_ep_ctx_blob_;
global_context_->enable_qdq_optimizer = info.enable_qdq_optimizer_;
global_context_->disable_cpu_fallback = info.disable_cpu_fallback_;
+ global_context_->ep_context_embed_mode = info.so_epctx_embed_mode_;
// to check if target device is available
// using ie_core capability GetAvailableDevices to fetch list of devices plugged in
@@ -47,7 +48,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv
info.device_type_.find("AUTO") != std::string::npos) {
device_found = true;
} else {
- for (std::string device : available_devices) {
+ for (const std::string& device : available_devices) {
if (device.rfind(info.device_type_, 0) == 0) {
if (info.device_type_.find("GPU") != std::string::npos && (info.precision_ == "FP32" ||
info.precision_ == "FP16" ||
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h
index 050fb91c5177..030e5bba71b6 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h
@@ -16,16 +16,23 @@
namespace onnxruntime {
+struct OVDevices {
+ ov::Core core;
+ std::vector get_ov_devices() const {
+ return core.get_available_devices();
+ }
+};
+
static void print_build_options() {
std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl;
std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority "
<< "you want to build"
<< std::endl;
std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build "
- << "are ['CPU','GPU','NPU']"
+ << "are ['CPU','GPU','NPU','GPU.x'] where x = 0,1,2 and so on"
<< std::endl;
std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. "
- << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU"
+ << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU Ex: AUTO:GPU.0,CPU Ex: AUTO:GPU.1,CPU"
<< std::endl;
}
@@ -40,7 +47,8 @@ static std::vector split(const std::string& s, char delim) {
return result;
}
-static std::vector parseDevices(const std::string& device_string) {
+static std::vector parseDevices(const std::string& device_string,
+ const std::vector& available_devices) {
std::string comma_separated_devices = device_string;
if (comma_separated_devices.find(":") != std::string::npos) {
comma_separated_devices = comma_separated_devices.substr(comma_separated_devices.find(":") + 1);
@@ -50,8 +58,15 @@ static std::vector parseDevices(const std::string& device_string) {
print_build_options();
ORT_THROW("Invalid device string: " + device_string);
}
- std::vector dev_options = {"CPU", "GPU", "NPU"};
- for (std::string dev : devices) {
+ std::set dev_options = {"CPU", "GPU", "NPU"};
+
+ for (auto& device : available_devices) {
+ if (dev_options.find(device) == dev_options.end()) {
+ auto dev_options_update = dev_options.emplace(device);
+ }
+ }
+
+ for (const std::string& dev : devices) {
if (!std::count(dev_options.begin(), dev_options.end(), dev)) {
print_build_options();
ORT_THROW("Invalid device string: " + device_string);
@@ -75,28 +90,42 @@ struct OpenVINOExecutionProviderInfo {
bool export_ep_ctx_blob_{false};
bool enable_qdq_optimizer_{false};
bool disable_cpu_fallback_{false};
+ bool so_epctx_embed_mode_{true};
OpenVINOExecutionProviderInfo() = delete;
- explicit OpenVINOExecutionProviderInfo(std::string dev_type, std::string precision, bool enable_npu_fast_compile,
- size_t num_of_threads, std::string cache_dir, std::string model_priority,
+ explicit OpenVINOExecutionProviderInfo(const std::string& dev_type, const std::string& precision,
+ bool enable_npu_fast_compile, size_t num_of_threads,
+ const std::string& cache_dir, const std::string& model_priority,
int num_streams, void* context, bool enable_opencl_throttling,
bool disable_dynamic_shapes, bool export_ep_ctx_blob,
- bool enable_qdq_optimizer, bool disable_cpu_fallback)
- : precision_(precision),
+ bool enable_qdq_optimizer, bool disable_cpu_fallback,
+ bool so_epctx_embed_mode)
+ : precision_(std::move(precision)),
enable_npu_fast_compile_(enable_npu_fast_compile),
num_of_threads_(num_of_threads),
cache_dir_(std::move(cache_dir)),
- model_priority_(model_priority),
+ model_priority_(std::move(model_priority)),
num_streams_(num_streams),
context_(context),
enable_opencl_throttling_(enable_opencl_throttling),
disable_dynamic_shapes_(disable_dynamic_shapes),
export_ep_ctx_blob_(export_ep_ctx_blob),
enable_qdq_optimizer_(enable_qdq_optimizer),
- disable_cpu_fallback_(disable_cpu_fallback) {
+ disable_cpu_fallback_(disable_cpu_fallback),
+ so_epctx_embed_mode_{so_epctx_embed_mode} {
std::set ov_supported_device_types = {"CPU", "GPU",
"GPU.0", "GPU.1", "NPU"};
+
+ OVDevices devices;
+ std::vector available_devices = devices.get_ov_devices();
+
+ for (auto& device : available_devices) {
+ if (ov_supported_device_types.find(device) == ov_supported_device_types.end()) {
+ ov_supported_device_types.emplace(device);
+ }
+ }
+
if (dev_type == "") {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP]"
<< "No runtime device selection option provided.";
@@ -116,7 +145,7 @@ struct OpenVINOExecutionProviderInfo {
dev_type = DEVICE;
if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) {
- std::vector devices = parseDevices(dev_type);
+ std::vector devices = parseDevices(dev_type, available_devices);
precision_ = "FP16";
if (devices[0] == "CPU") {
precision_ = "FP32";
@@ -127,7 +156,7 @@ struct OpenVINOExecutionProviderInfo {
} else if (ov_supported_device_types.find(dev_type) != ov_supported_device_types.end()) {
device_type_ = std::move(dev_type);
} else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) {
- std::vector devices = parseDevices(dev_type);
+ std::vector devices = parseDevices(dev_type, available_devices);
device_type_ = dev_type;
} else {
ORT_THROW("Invalid device string: " + dev_type);
diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
index 45bba431741c..716a7cd93640 100644
--- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
+++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
@@ -14,7 +14,8 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory {
int num_streams, void* context,
bool enable_opencl_throttling, bool disable_dynamic_shapes,
bool export_ep_ctx_blob, bool enable_qdq_optimizer,
- bool disable_cpu_fallback)
+ bool disable_cpu_fallback,
+ bool so_epctx_embed_mode)
: precision_(precision),
enable_npu_fast_compile_(enable_npu_fast_compile),
num_of_threads_(num_of_threads),
@@ -25,10 +26,12 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory {
disable_dynamic_shapes_(disable_dynamic_shapes),
export_ep_ctx_blob_(export_ep_ctx_blob),
enable_qdq_optimizer_(enable_qdq_optimizer),
- disable_cpu_fallback_(disable_cpu_fallback) {
+ disable_cpu_fallback_(disable_cpu_fallback),
+ so_epctx_embed_mode_(so_epctx_embed_mode) {
device_type_ = (device_type == nullptr) ? "" : device_type;
cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir;
}
+
~OpenVINOProviderFactory() override {
}
@@ -48,13 +51,15 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory {
bool export_ep_ctx_blob_;
bool enable_qdq_optimizer_;
bool disable_cpu_fallback_;
+ bool so_epctx_embed_mode_;
};
std::unique_ptr OpenVINOProviderFactory::CreateProvider() {
OpenVINOExecutionProviderInfo info(device_type_, precision_, enable_npu_fast_compile_, num_of_threads_,
cache_dir_, model_priority_, num_streams_, context_, enable_opencl_throttling_,
disable_dynamic_shapes_, export_ep_ctx_blob_, enable_qdq_optimizer_,
- disable_cpu_fallback_);
+ disable_cpu_fallback_,
+ so_epctx_embed_mode_);
return std::make_unique(info);
}
@@ -105,6 +110,8 @@ struct OpenVINO_Provider : Provider {
bool disable_cpu_fallback = false;
+ bool so_epctx_embed_mode = true;
+
if (provider_options_map.find("device_type") != provider_options_map.end()) {
device_type = provider_options_map.at("device_type").c_str();
@@ -113,6 +120,14 @@ struct OpenVINO_Provider : Provider {
std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32",
"GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
"GPU.0_FP16", "GPU.1_FP16"};
+ OVDevices devices;
+ std::vector available_devices = devices.get_ov_devices();
+
+ for (auto& device : available_devices) {
+ if (ov_supported_device_types.find(device) == ov_supported_device_types.end()) {
+ ov_supported_device_types.emplace(device);
+ }
+ }
if (deprecated_device_types.find(device_type) != deprecated_device_types.end()) {
std::string deprecated_device = device_type;
int delimit = device_type.find("_");
@@ -128,8 +143,8 @@ struct OpenVINO_Provider : Provider {
(device_type.find("MULTI:") == 0) ||
(device_type.find("AUTO:") == 0))) {
ORT_THROW(
- "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. "
- "Select from 'CPU', 'GPU', 'GPU.0', 'GPU.1', 'NPU' or from"
+ "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. "
+ "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from"
" HETERO/MULTI/AUTO options available. \n");
}
}
@@ -253,9 +268,8 @@ struct OpenVINO_Provider : Provider {
}
}
}
-
- if (provider_options_map.find("export_ep_ctx_blob") != provider_options_map.end()) {
- bool_flag = provider_options_map.at("export_ep_ctx_blob");
+ if (provider_options_map.find("so_export_ep_ctx_blob") != provider_options_map.end()) {
+ bool_flag = provider_options_map.at("so_export_ep_ctx_blob");
if (bool_flag == "true" || bool_flag == "True")
export_ep_ctx_blob = true;
else if (bool_flag == "false" || bool_flag == "False")
@@ -271,6 +285,23 @@ struct OpenVINO_Provider : Provider {
disable_cpu_fallback = false;
bool_flag = "";
}
+ if (provider_options_map.find("so_epctx_embed_mode") != provider_options_map.end()) {
+ bool_flag = provider_options_map.at("so_epctx_embed_mode");
+ if (bool_flag == "true" || bool_flag == "True")
+ so_epctx_embed_mode = true;
+ else if (bool_flag == "false" || bool_flag == "False")
+ so_epctx_embed_mode = false;
+ bool_flag = "";
+ }
+
+ if (provider_options_map.find("so_epctx_path") != provider_options_map.end()) {
+ // The path to dump epctx model is valid only when epctx is enabled.
+ // Overrides the cache_dir option to dump model cache files from OV.
+ if (export_ep_ctx_blob) {
+ cache_dir = provider_options_map.at("so_epctx_path").c_str();
+ }
+ }
+
return std::make_shared(const_cast(device_type.c_str()),
const_cast(precision.c_str()),
enable_npu_fast_compile,
@@ -283,7 +314,8 @@ struct OpenVINO_Provider : Provider {
disable_dynamic_shapes,
export_ep_ctx_blob,
enable_qdq_optimizer,
- disable_cpu_fallback);
+ disable_cpu_fallback,
+ so_epctx_embed_mode);
}
void Initialize() override {
diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc
index 8dd00857b7dd..7e8681d304ab 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.cc
+++ b/onnxruntime/core/providers/openvino/ov_interface.cc
@@ -63,7 +63,6 @@ std::shared_ptr OVCore::ReadModel(const std::string& model, const std
return FE->convert(inputModel);
} else {
ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network");
- return NULL;
}
} catch (const Exception& e) {
ORT_THROW(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what()));
@@ -73,9 +72,9 @@ std::shared_ptr OVCore::ReadModel(const std::string& model, const std
}
OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_network,
- std::string hw_target,
- const ov::AnyMap& device_config,
- std::string name) {
+ std::string& hw_target,
+ ov::AnyMap& device_config,
+ const std::string& name) {
ov::CompiledModel obj;
try {
obj = oe.compile_model(ie_cnn_network, hw_target, device_config);
@@ -92,22 +91,12 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo
}
OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
- std::string hw_target,
- std::string precision,
- std::string cache_dir,
- const ov::AnyMap& device_config,
- std::string name) {
+ std::string& hw_target,
+ ov::AnyMap& device_config,
+ const std::string& name) {
ov::CompiledModel obj;
try {
- if (hw_target == "AUTO:GPU,CPU") {
- obj = oe.compile_model(onnx_model, ov::Tensor(),
- "AUTO",
- ov::device::priorities("GPU", "CPU"),
- ov::device::properties("GPU", {ov::cache_dir(cache_dir),
- ov::hint::inference_precision(precision)}));
- } else {
- obj = oe.compile_model(onnx_model, ov::Tensor(), hw_target, device_config);
- }
+ obj = oe.compile_model(onnx_model, ov::Tensor(), hw_target, device_config);
#ifndef NDEBUG
printDebugInfo(obj);
#endif
@@ -123,9 +112,19 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream,
std::string hw_target,
const ov::AnyMap& device_config,
+ bool embed_mode,
std::string name) {
try {
- auto obj = oe.import_model(*model_stream, hw_target, device_config);
+ ov::CompiledModel obj;
+ if (embed_mode) {
+ obj = oe.import_model(*model_stream, hw_target, device_config);
+ } else {
+ std::string blob_file_path = (*model_stream).str();
+ std::ifstream modelStream(blob_file_path, std::ios_base::binary | std::ios_base::in);
+ obj = oe.import_model(modelStream,
+ hw_target,
+ {});
+ }
#ifndef NDEBUG
printDebugInfo(obj);
#endif
@@ -138,10 +137,8 @@ OVExeNetwork OVCore::ImportModel(std::shared_ptr model_strea
}
}
-void OVCore::SetCache(std::string cache_dir_path, std::string device_type) {
- if (device_type != "AUTO:GPU,CPU") {
- oe.set_property(ov::cache_dir(cache_dir_path));
- }
+void OVCore::SetCache(const std::string& cache_dir_path) {
+ oe.set_property(ov::cache_dir(cache_dir_path));
}
#ifdef IO_BUFFER_ENABLED
diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h
index af6f252feb2c..fa22e0f3cb03 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.h
+++ b/onnxruntime/core/providers/openvino/ov_interface.h
@@ -40,20 +40,23 @@ class OVCore {
ov::Core oe;
public:
+ // OV Interface For Reading Model
std::shared_ptr ReadModel(const std::string& model_stream, const std::string& model_path) const;
+ // OV Interface for Compiling OV Model Type
OVExeNetwork CompileModel(std::shared_ptr& ie_cnn_network,
- std::string hw_target,
- const ov::AnyMap& device_config,
- std::string name);
+ std::string& hw_target,
+ ov::AnyMap& device_config,
+ const std::string& name);
+ // OV Interface for Fast Compile
OVExeNetwork CompileModel(const std::string& onnx_model,
- std::string hw_target,
- std::string precision,
- std::string cache_dir,
- const ov::AnyMap& device_config,
- std::string name);
+ std::string& hw_target,
+ ov::AnyMap& device_config,
+ const std::string& name);
+ // OV Interface for Import model Stream
OVExeNetwork ImportModel(std::shared_ptr model_stream,
std::string hw_target,
const ov::AnyMap& device_config,
+ bool embed_mode,
std::string name);
#ifdef IO_BUFFER_ENABLED
OVExeNetwork CompileModel(std::shared_ptr& model,
@@ -64,7 +67,7 @@ class OVCore {
std::string name);
#endif
std::vector GetAvailableDevices();
- void SetCache(std::string cache_dir_path, std::string device_type);
+ void SetCache(const std::string& cache_dir_path);
ov::Core& Get() { return oe; }
void SetStreams(const std::string& device_type, int num_streams);
};
diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc
index 856b97a0896d..3fcaff4369c8 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc
@@ -35,18 +35,16 @@ GetCapability::GetCapability(const GraphViewer& graph_viewer_param,
device_type_ = "CPU";
if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true;
}
-#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1
- data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, npu_qdq_optimizer_enabled);
-#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2
- data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, npu_qdq_optimizer_enabled);
-#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3
- data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, npu_qdq_optimizer_enabled);
-#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0
+#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0
data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, npu_qdq_optimizer_enabled);
#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1
data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
+#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 2
+ data_ops_ = new DataOps(graph_viewer_, V_2024_2, device_type_, npu_qdq_optimizer_enabled);
+#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 3
+ data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled);
#else
- data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled);
+ data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled);
#endif
}
diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
index 38c029faff9d..d9aa13ec1bba 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
@@ -142,6 +142,7 @@ std::vector supported_op_mode = {
{"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}},
{"GridSample", V_2022_3, {"CPU"}},
{"GridSample", V_2023_0, {"GPU"}},
+ {"GRU", V_2024_1, {"CPU", "GPU"}},
{"HardMax", V_2023_1, {"CPU", "GPU"}},
{"Identity", V_2020_4, {"CPU", "GPU"}},
{"If", V_2022_3, {"CPU", "GPU"}},
@@ -155,6 +156,7 @@ std::vector supported_op_mode = {
{"LessOrEqual", V_2022_1, {"CPU", "GPU"}},
{"Log", V_2020_4, {"CPU", "GPU"}},
{"LogSoftMax", V_2022_1, {"CPU", "GPU"}},
+ {"LogSoftmax", V_2024_1, {"CPU", "GPU"}},
{"Loop", V_2021_4, {"CPU", "GPU"}},
{"LpNormalization", V_2023_1, {"CPU", "GPU"}},
{"LRN", V_2020_4, {"CPU", "GPU"}},
@@ -361,7 +363,7 @@ void DataOps::populate_op_mode_supported() {
// populate unsupportedmode_t
{
- UnsupportedOpMode obj = {{V_2024_1},
+ UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3},
[this](const Node* node, const InitializedTensorSet&) {
// If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch)
for (size_t i = 0; i < node->InputDefs().size(); i++) {
@@ -376,7 +378,7 @@ void DataOps::populate_op_mode_supported() {
op_list_.insert({"ReduceMax", obj});
}
{
- UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1},
+ UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3},
[this](const Node* node, const InitializedTensorSet&) {
const auto& input_arg = node->InputDefs()[1];
auto shape = input_arg->Shape();
@@ -393,7 +395,7 @@ void DataOps::populate_op_mode_supported() {
op_list_.insert({"Reshape", obj});
}
{
- UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1},
+ UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3},
[this](const Node* node, const InitializedTensorSet&) {
// If the operator is unsqueeze
// If axes is an input, then we cannot produce a static graph.
@@ -408,7 +410,7 @@ void DataOps::populate_op_mode_supported() {
op_list_.insert({"Unsqueeze", obj});
}
{
- UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1},
+ UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3},
[this](const Node* node, const InitializedTensorSet&) {
// check for attributes
auto& upsample_attr = node->GetAttributes();
diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h
index 7cfb0516b8cc..4c064b08405c 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h
+++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h
@@ -28,7 +28,9 @@ enum versionNum {
V_2023_2,
V_2023_3,
V_2024_0,
- V_2024_1
+ V_2024_1,
+ V_2024_2,
+ V_2024_3
};
using VersionNum = enum versionNum;
diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
index c7689a0be7e7..a2b3ed068235 100644
--- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
+++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
@@ -205,11 +205,11 @@ static bool IsConnectedQAConstantInitializer(const Node* dq_node, const onnxrunt
// Check required because in some cases, when a NodeUnit cannot be formed with this standalone DQ
// we still need to check if it feeds into a supported Op
-static bool DQFeedsASupportedOp(const Node* dq_node, const onnxruntime::GraphViewer& src_graph) {
+static bool DQFeedsASupportedOp(const Node* dq_node) {
if (!dq_node->GetOutputEdgesCount()) return false; // Only feeds the graph output, and not any node
const auto& target_node = *dq_node->OutputNodesBegin();
- const auto op_type = target_node.OpType();
+ const auto& op_type = target_node.OpType();
if (op_type == "Conv" || op_type == "MatMul") {
// Conv and MatMul always keeps int8 DQs except if the DQ is sandwiched between Softmax and Conv/MatMul
@@ -219,8 +219,8 @@ static bool DQFeedsASupportedOp(const Node* dq_node, const onnxruntime::GraphVie
return true;
}
} else if (op_type == "Add") {
- // Add keeps all DQs except if it has const inits
- return !IsAnyDQAConstantInitializer(&target_node, src_graph);
+ // Add => keeps all DQs
+ return true;
}
return false;
}
@@ -291,7 +291,7 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit,
const onnxruntime::GraphViewer& src_graph,
SkipReason& reason) {
const auto& target_node = node_unit.GetNode();
- auto op_type = node_unit.OpType();
+ const auto& op_type = node_unit.OpType();
// #1 Reverse DQ duplication
if (dq_node->Name().find(DuplicateDQ) != std::string::npos) {
@@ -337,6 +337,18 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit,
}
}
+static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit,
+ const std::unordered_map graph_op_data_type) {
+ auto op_of_quantized_layer = node_unit.Outputs();
+ for (auto& itr : op_of_quantized_layer) {
+ auto it = graph_op_data_type.find(itr.node_arg.Name());
+ if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") {
+ return true;
+ }
+ }
+ return false;
+}
+
static bool CheckQRuleSet(const NodeUnit& node_unit,
const Node* q_node,
const onnxruntime::GraphViewer& src_graph,
@@ -345,7 +357,13 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
// This Q should also be uint8
const auto& target_node = node_unit.GetNode();
- auto op_type = node_unit.OpType();
+ const auto& op_type = node_unit.OpType();
+
+ auto op = src_graph.GetOutputs();
+ std::unordered_map graph_op_data_type;
+ for (auto& ops : op) {
+ graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data();
+ }
// If UInt16 Q, don't keep it
if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) {
@@ -359,6 +377,8 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
} else if (op_type == "Add") {
// Add keeps all Qs
return true;
+ } else if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) {
+ return true;
} else {
// Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list
return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false);
@@ -469,7 +489,7 @@ static void AddStandaloneNodeUnit(onnxruntime::Graph& dst_graph, const onnxrunti
add_identity_op(true);
else if (IsConnectedQPresent(src_graph, dst_graph.Nodes(), &node_unit.GetNode(), node_unit.GetNode().InputDefs()))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
- else if (DQFeedsASupportedOp(&node_unit.GetNode(), src_graph))
+ else if (DQFeedsASupportedOp(&node_unit.GetNode()))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
else
add_identity_op(false);
@@ -543,7 +563,7 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph,
// Add Node args for inputs
for (const auto& node_unit_input : node_unit_inputs) {
- auto node_arg_name = node_unit_input.node_arg.Name();
+ const auto& node_arg_name = node_unit_input.node_arg.Name();
if (auto dq_node_arg = dq_node_args_to_keep.find(node_arg_name); dq_node_arg != dq_node_args_to_keep.end()) {
// Add supported DQ as an input arg for the target node
input_args.push_back(dq_node_arg->second);
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 1d21933e9cba..924158a26b92 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -1931,12 +1931,31 @@ void ORTSessionOptionsToOrtOpenVINOProviderOptions(ProviderOptions& ov_options,
kOrtSessionOptionsDisableCPUEPFallback, "0") == "1";
if (disable_cpu_fallback)
ov_options["disable_cpu_fallback"] = "true";
+
+ // values from session options will override the providerOptions Value
+ bool so_epctx_enable = session_options->config_options.GetConfigOrDefault(
+ kOrtSessionOptionEpContextEnable, "0") == "1";
+ if (so_epctx_enable)
+ ov_options["so_export_ep_ctx_blob"] = "true";
+
+ std::string so_cache_path = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "").c_str();
+ ov_options["so_epctx_path"] = so_cache_path;
+
+ // Default embedMode is 1. Saving the compiled model contents as a Epctx node attribute
+ bool so_epctx_embed_mode = session_options->config_options.GetConfigOrDefault(
+ kOrtSessionOptionEpContextEmbedMode, "1") == "0";
+ if (so_epctx_embed_mode) {
+ // defaults to true
+ ov_options["so_epctx_embed_mode"] = "false";
+ }
}
std::shared_ptr OpenVINOProviderFactoryCreator::Create(ProviderOptions* provider_options_map,
const SessionOptions* session_options) {
- if (session_options)
+ // Append session options applicable for EP to EP Provider options.
+ if (session_options) {
onnxruntime::ORTSessionOptionsToOrtOpenVINOProviderOptions(*provider_options_map, session_options);
+ }
return s_library_openvino.Get().CreateExecutionProviderFactory(provider_options_map);
}
diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc
index 2cbfbbb31764..03a71868a3dc 100644
--- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc
@@ -246,14 +246,14 @@ Status TestGraphTransformer(const std::function&
ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
}
#if SAVE_TEST_GRAPH
- ORT_RETURN_IF_ERROR(Model::Save(model, "model_original.onnx"));
+ ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_original.onnx")));
#endif
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
if (post_graph_checker) {
ORT_RETURN_IF_ERROR(post_graph_checker(graph));
}
#if SAVE_TEST_GRAPH
- ORT_RETURN_IF_ERROR(Model::Save(model, "model_optimized.onnx"));
+ ORT_RETURN_IF_ERROR(Model::Save(model, ToPathString("model_optimized.onnx")));
#endif
};
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 14c5b60d6e0b..fb85eb4c29bb 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -12,6 +12,7 @@
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/double_qdq_pairs_remover.h"
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
+#include "core/optimizer/qdq_transformer/qdq_propagation.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
@@ -3084,6 +3085,57 @@ TEST(QDQTransformerTests, QDQPropagation_QBackward) {
#endif
}
+// Test backwards propagation of a QuantizeLinear node that uses the "output_dtype" attribute
+// to set the quantization type (i.e., does not have an explicit zero-point input). This tests
+// the copying of attributes for QDQ propagation.
+TEST(QDQTransformerTests, QDQPropagation_QBackward_NoZP_OutputDtypeAttribute) {
+ auto test_case = [&](ONNX_NAMESPACE::TensorProto_DataType q_output_type) {
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({1, 2, 2}, {-2.0f, 0.0f, 1.0f, 2.0f});
+ auto* output_arg = builder.MakeOutput();
+
+ // add Add
+ auto* const_1_input = builder.MakeScalarInitializer(1.0f);
+ auto* add_output = builder.MakeIntermediate();
+ builder.AddNode("Add", {input_arg, const_1_input}, {add_output});
+
+ // add Transpose
+ auto* transpose_output = builder.MakeIntermediate();
+ builder.AddNode("Transpose", {add_output}, {transpose_output});
+
+ // add Q with a "output_dtype" attribute. Omit the zero-point input (defaults to 0).
+ constexpr float qdq_scale = 1.0f;
+ Node& q_node = builder.AddQuantizeLinearNode(transpose_output, qdq_scale, output_arg);
+ q_node.AddAttribute("output_dtype", static_cast(q_output_type));
+ };
+
+ auto check_graph = [&](InferenceSessionWrapper& session) {
+ const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
+ std::vector expected_op_types_in_order = {
+ "Add",
+ qdq_keys.quantize_linear,
+ qdq_keys.dequantize_linear,
+ "Transpose",
+ qdq_keys.quantize_linear,
+ };
+
+ const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true);
+ EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
+ };
+
+ TransformerTester(build_test_case,
+ check_graph,
+ TransformerLevel::Default,
+ TransformerLevel::Level1,
+ 21); // Opset >= 21 supports the "output_dtype" attribute
+ };
+
+ test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
+ test_case(ONNX_NAMESPACE::TensorProto_DataType_INT8);
+ test_case(ONNX_NAMESPACE::TensorProto_DataType_UINT16);
+ test_case(ONNX_NAMESPACE::TensorProto_DataType_INT16);
+}
+
TEST(QDQTransformerTests, QDQPropagation_DQForward) {
auto test_case = [&](const std::vector& input_shape,
size_t maxpool_dim,
@@ -3420,6 +3472,122 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) {
#endif
}
+// Test propagating a DQ forward through a chain of Slice and Transpose operators that have multiple consumers.
+// original model:
+// in0 -> DQ -> Slice --+--> slice_out
+// |
+// +--> Add -> out0
+// |
+// +--> Transpose --+--> Pow -> out1
+// | |
+// | +--> Pow -> out2
+// |
+// +--> Transpose --+--> Pow -> out3
+// |
+// +--> Pow -> out4
+// expected model:
+// in0 -> DQ -> Slice -> Q --+--> DQ -> slice_out
+// |
+// +--> DQ -> Add -> out0
+// |
+// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out1
+// | |
+// | +--> DQ -> Pow -> out2
+// |
+// +--> DQ -> TP -> Q --+--> DQ -> Pow -> out3
+// |
+// +--> DQ -> Pow -> out4
+TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) {
+ auto run_test_case = [&](bool slice_has_graph_output) {
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ std::vector input0_shape = {1, 2, 2, 2};
+ std::vector input1_shape = {1, 1, 1, 1};
+ auto* input0_arg = builder.MakeInput(input0_shape,
+ std::numeric_limits::min(),
+ std::numeric_limits::max());
+ auto* input1_arg = builder.MakeInput(input1_shape, {0.0f});
+ auto* output0_arg = builder.MakeOutput();
+ auto* output1_arg = builder.MakeOutput();
+ auto* output2_arg = builder.MakeOutput();
+ auto* output3_arg = builder.MakeOutput();
+ auto* output4_arg = builder.MakeOutput();
+
+ // DQ
+ constexpr float qdq_scale = 1.0f;
+ constexpr uint8_t qdq_zero_point = 128;
+ auto* dq_output = builder.MakeIntermediate();
+ builder.AddDequantizeLinearNode(input0_arg, qdq_scale, qdq_zero_point, dq_output);
+
+ // Slice
+ auto* slice_output = slice_has_graph_output ? builder.MakeOutput() : builder.MakeIntermediate();
+ auto* slice_starts = builder.Make1DInitializer(std::vector{0, 0, 0, 0});
+ auto* slice_ends = builder.Make1DInitializer(std::vector{1, 1, 1, 1});
+ builder.AddNode("Slice", {dq_output, slice_starts, slice_ends}, {slice_output});
+
+ // Add
+ builder.AddNode("Add", {slice_output, input1_arg}, {output0_arg});
+
+ // Transpose
+ auto* transpose0_output = builder.MakeIntermediate();
+ builder.AddNode("Transpose", {slice_output}, {transpose0_output});
+
+ // Transpose
+ auto* transpose1_output = builder.MakeIntermediate();
+ builder.AddNode("Transpose", {slice_output}, {transpose1_output});
+
+ // Pows
+ auto* pow_exp = builder.MakeScalarInitializer(2.0f);
+ builder.AddNode("Pow", {transpose0_output, pow_exp}, {output1_arg});
+ builder.AddNode("Pow", {transpose0_output, pow_exp}, {output2_arg});
+ builder.AddNode("Pow", {transpose1_output, pow_exp}, {output3_arg});
+ builder.AddNode("Pow", {transpose1_output, pow_exp}, {output4_arg});
+ };
+
+ auto check_graph = [&](InferenceSessionWrapper& session) {
+ const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
+ std::vector expected_op_types_in_order;
+ expected_op_types_in_order.reserve(20);
+ expected_op_types_in_order.insert(expected_op_types_in_order.end(),
+ {qdq_keys.dequantize_linear,
+ "Slice",
+ qdq_keys.quantize_linear});
+
+ if (slice_has_graph_output) {
+ // Should have a DQ before the graph output generated by the Slice.
+ expected_op_types_in_order.push_back(qdq_keys.dequantize_linear);
+ }
+
+ expected_op_types_in_order.insert(expected_op_types_in_order.end(),
+ {qdq_keys.dequantize_linear,
+ "Add",
+ qdq_keys.dequantize_linear,
+ "Transpose",
+ qdq_keys.quantize_linear, qdq_keys.dequantize_linear,
+ "Pow",
+ qdq_keys.dequantize_linear,
+ "Pow",
+ qdq_keys.dequantize_linear,
+ "Transpose",
+ qdq_keys.quantize_linear, qdq_keys.dequantize_linear,
+ "Pow",
+ qdq_keys.dequantize_linear,
+ "Pow"});
+
+ const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true);
+ EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
+ };
+
+ TransformerTester(build_test_case,
+ check_graph,
+ TransformerLevel::Default,
+ TransformerLevel::Level1,
+ 18, 0.0, 0.0, std::make_unique());
+ };
+
+ run_test_case(/*slice_has_graph_output*/ false);
+ run_test_case(/*slice_has_graph_output*/ true);
+}
+
TEST(QDQTransformerTests, QDQ_Selector_Test) {
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx");
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index e6d4e0a94abd..84c3bc16346f 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -253,7 +253,6 @@ static bool ParseSessionConfigs(const std::string& configs_string,
test_config.machine_config.provider_type_name = onnxruntime::kDnnlExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("openvino"))) {
test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider;
- test_config.run_config.optimization_level = ORT_DISABLE_ALL;
} else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) {
test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("qnn"))) {
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 72b5da7aaec9..fc1bdb10d745 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -699,6 +699,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32",
"GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
"GPU.0_FP16", "GPU.1_FP16"};
+ size_t num_gpus = 10;
+ for (size_t i = 0; i <= num_gpus; i++) {
+ ov_supported_device_types.emplace("GPU." + std::to_string(i));
+ }
if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) {
ov_options[key] = value;
} else if (deprecated_device_types.find(value) != deprecated_device_types.end()) {
diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc
index d0e08448ce45..5f332ddcddb8 100644
--- a/onnxruntime/test/providers/checkers.cc
+++ b/onnxruntime/test/providers/checkers.cc
@@ -25,7 +25,15 @@ struct DefaultTolerance {
static constexpr float relative = 1e-5f;
// Allow to have different default absolute tolerance for different providers.
- static float get_absolute(const std::string& /*provider_type*/) {
+ static float get_absolute(const std::string& provider_type /*provider_type*/) {
+ if (provider_type == kOpenVINOExecutionProvider) {
+#ifdef OPENVINO_CONFIG_NPU
+ return 0.005f;
+#else
+ return absolute;
+#endif
+ }
+
return absolute;
}
};
@@ -40,7 +48,15 @@ struct DefaultTolerance {
static constexpr float relative = 1e-4f;
- static float get_absolute(const std::string& /*provider_type*/) {
+ static float get_absolute(const std::string& provider_type /*provider_type*/) {
+ if (provider_type == kOpenVINOExecutionProvider) {
+#ifdef OPENVINO_CONFIG_NPU
+ return 0.005f;
+#else
+ return absolute;
+#endif
+ }
+
return absolute;
}
};
diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc
index b05649dafc18..30960e71c577 100644
--- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc
+++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc
@@ -98,8 +98,12 @@ static void RunGruTest(const std::vector& X_data,
test.AddOptionalOutputEdge();
}
- // TensorRT failed on GRU tests
+// TensorRT, OpenVINO failed on GRU tests
+#if defined(USE_OPENVINO)
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+#else
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+#endif
}
void DefaultActivationsSimpleWeightsNoBias(std::string direction,
diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py b/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py
index b5400b487cfc..c245699e211d 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-# -*- coding: UTF-8 -*-
import unittest
import numpy as np
@@ -10,7 +9,7 @@
import onnxruntime.backend as backend
from onnxruntime import datasets
-from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend # noqa: N813
+from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend
def check_list_of_map_to_float(testcase, expected_rows, actual_rows):
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
index 8d110c692751..1135ef41cfc4 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
@@ -67,410 +67,422 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap;
* or not.
* 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not.
*/
-const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) {
- static InlinedHashMap> recomputable_op_table_map;
- if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) {
- return recomputable_op_table_map.at(probe_op_level);
- }
+InlinedHashMap> InitializeRecomputableOpTable() {
+ InlinedHashMap> recomputable_op_table_map;
+
+ constexpr const int basic_op_level = static_cast(ProbeLevel::Basic);
+ recomputable_op_table_map.insert({basic_op_level, InlinedHashMap()});
+ auto& basic_recomputable_op_table = recomputable_op_table_map.at(basic_op_level);
+
+ basic_recomputable_op_table.insert({
+ {
+ utils::GetFullQualifiedOpName("Add", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {7, {}},
+ {13, {}},
+ {14, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {7, {}},
+ {9, {}},
+ {14, {}},
+ {15, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
+ {
+ {1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
+ {
+ {1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
+ {
+ {1, {1, 2}}, // ignore ratio (optional) and training mode (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {9, {}},
+ {13, {}},
+ {19, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
+ {
+ {1, {}},
+
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
+ {
+ {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor
+ {20, {0}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
+ {
+ {7, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
+ {
+ // The axis input is trivial
+ {11, {1}},
+ {14, {1}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
+ {
+ // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
+ {12, {1, 2}}, // ignore ratio and training_mode
+ {13, {1, 2}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Div", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {7, {}},
+ {13, {}},
+ {14, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
+ {
+ {12, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
+ {
+ {1, {}},
+ {7, {}},
+ {11, {}},
+ {13, {}},
+ {19, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
+ {
+ {8, {1}}, // Ignore the shape.
+ {13, {1}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
+ {
+ {1, {1}}, // ignore the indices
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
+ {
+ {1, {1}}, // ignore the indices
+ {11, {1}},
+ {13, {1}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
+ {
+ {20, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Gelu", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {7, {}},
+ {9, {}},
+ {11, {}},
+ {13, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Less", kOnnxDomain),
+ {
+ {1, {}},
+ {7, {}},
+ {9, {}},
+ {13, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
+ {
+ {1, {0}}, // Ignore CPU input.
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {7, {}},
+ {13, {}},
+ {14, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {13, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
+ {
+ {9, {}},
+ {13, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
+ {
+ {1, {1, 2}}, // ignore the indices and unflatten_dims
+ },
+ },
+ {
+ // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check.
+ utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Range", kOnnxDomain),
+ {
+ {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars.
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
+ {
+ {1, {}},
+ {5, {}}, // ignore the shape.
+ {13, {}},
+ {14, {}},
+ {19, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
+ {
+ {7, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
+ {
+ {1, {}},
+ {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional)
+ {11, {1, 2, 3, 4}},
+ {13, {1, 2, 3, 4}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Split", kOnnxDomain),
+ {
+ {1, {1}}, // ignore split (optional)
+ {2, {}},
+ {11, {}},
+ {13, {1}}, // ignore the split (optional)
+ {18, {1}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
+ {
+ {1, {}},
+ {11, {}},
+ {13, {1}}, // ignore the axes (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
+ {
+ {1, {}},
+ {6, {}},
+ {7, {}},
+ {13, {}},
+ {14, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
+ {
+ {1, {1, 2}},
+ {6, {1}},
+ {13, {1}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
+ {
+ {1, {}},
+ {13, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
+ {
+ {14, {1}}, // ignore k (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
+ {
+ {1, {}},
+ {11, {}},
+ {13, {1}}, // ignore the axes (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Where", kOnnxDomain),
+ {
+ {9, {}},
+ {16, {}},
+ },
+ },
+
+ });
+
+ constexpr const int advanced_op_level = static_cast(ProbeLevel::Advanced);
+ recomputable_op_table_map.insert({advanced_op_level, InlinedHashMap()});
+ auto& advanced_recomputable_op_table = recomputable_op_table_map.at(advanced_op_level);
+ // Append basic_recomputable_op_table to advanced_recomputable_op_table.
+ advanced_recomputable_op_table.insert(recomputable_op_table_map.at(basic_op_level).begin(),
+ recomputable_op_table_map.at(basic_op_level).end());
+
+ advanced_recomputable_op_table.insert({
+ {
+ utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
+ {
+ {1, {2}}, // ignore ratio (optional)
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
+ {
+ // Opset 1 in ONNX official does not have LayerNormalization,
+ // while our contrib op defined LayerNormalization in opset 1 in ONNX domain.
+ {1, {}},
+ {17, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
+ {
+ {1, {}},
+ {9, {}},
+ {13, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
+ {
+ // Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
+ // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
+ {
+ {1, {}},
+ },
+ },
+ {
+ utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
+ {
+ {1, {}},
+ {11, {}},
+ {13, {}},
+ },
+ },
+ });
+
+ return recomputable_op_table_map;
+}
- recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()});
- auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level);
- if (probe_op_level >= static_cast(ProbeLevel::Basic)) {
- recomputable_op_table.insert({
- {
- utils::GetFullQualifiedOpName("Add", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {7, {}},
- {13, {}},
- {14, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {7, {}},
- {9, {}},
- {14, {}},
- {15, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
- {
- {1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
- {
- {1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
- {
- {1, {1, 2}}, // ignore ratio (optional) and training mode (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {9, {}},
- {13, {}},
- {19, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
- {
- {1, {}},
-
- },
- },
- {
- utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
- {
- {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor
- {20, {0}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
- {
- {7, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
- {
- // The axis input is trivial
- {11, {1}},
- {14, {1}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
- {
- // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
- {12, {1, 2}}, // ignore ratio and training_mode
- {13, {1, 2}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Div", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {7, {}},
- {13, {}},
- {14, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
- {
- {12, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
- {
- {1, {}},
- {7, {}},
- {11, {}},
- {13, {}},
- {19, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
- {
- {8, {1}}, // Ignore the shape.
- {13, {1}},
- },
- },
- {
- utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
- {
- {1, {1}}, // ignore the indices
- },
- },
- {
- utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
- {
- {1, {1}}, // ignore the indices
- {11, {1}},
- {13, {1}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
- {
- {20, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Gelu", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {7, {}},
- {9, {}},
- {11, {}},
- {13, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Less", kOnnxDomain),
- {
- {1, {}},
- {7, {}},
- {9, {}},
- {13, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
- {
- {1, {0}}, // Ignore CPU input.
- },
- },
- {
- utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {7, {}},
- {13, {}},
- {14, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {13, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
- {
- {9, {}},
- {13, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
- {
- {1, {1, 2}}, // ignore the indices and unflatten_dims
- },
- },
- {
- // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check.
- utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Range", kOnnxDomain),
- {
- {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars.
- },
- },
- {
- utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
- {
- {1, {}},
- {5, {}}, // ignore the shape.
- {13, {}},
- {14, {}},
- {19, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
- {
- {7, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
- {
- {1, {}},
- {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional)
- {11, {1, 2, 3, 4}},
- {13, {1, 2, 3, 4}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Split", kOnnxDomain),
- {
- {1, {1}}, // ignore split (optional)
- {2, {}},
- {11, {}},
- {13, {1}}, // ignore the split (optional)
- {18, {1}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
- {
- {1, {}},
- {11, {}},
- {13, {1}}, // ignore the axes (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
- {
- {1, {}},
- {6, {}},
- {7, {}},
- {13, {}},
- {14, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
- {
- {1, {1, 2}},
- {6, {1}},
- {13, {1}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
- {
- {1, {}},
- {13, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
- {
- {14, {1}}, // ignore k (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
- {
- {1, {}},
- {11, {}},
- {13, {1}}, // ignore the axes (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("Where", kOnnxDomain),
- {
- {9, {}},
- {16, {}},
- },
- },
-
- });
- }
+const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) {
+ static InlinedHashMap>
+ recomputable_op_table_map = InitializeRecomputableOpTable();
- if (probe_op_level >= static_cast(ProbeLevel::Advanced)) {
- recomputable_op_table.insert({
- {
- utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
- {
- {1, {2}}, // ignore ratio (optional)
- },
- },
- {
- utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
- {
- // Opset 1 in ONNX official does not have LayerNormalization,
- // while our contrib op defined LayerNormalization in opset 1 in ONNX domain.
- {1, {}},
- {17, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
- {
- {1, {}},
- {9, {}},
- {13, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
- {
- // Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
- // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
- {
- {1, {}},
- },
- },
- {
- utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
- {
- {1, {}},
- {11, {}},
- {13, {}},
- },
- },
- });
- }
+ ORT_ENFORCE(recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end(),
+ "Cannot get recomputable op table, probe level: ", probe_op_level);
- return recomputable_op_table;
+ return recomputable_op_table_map.at(probe_op_level);
}
/**
diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
index 90c97eed0c6d..be25eefb201d 100644
--- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
@@ -542,8 +542,9 @@ TEST(TrainingApiTest, OptimStep) {
std::string param_name = "fc2.weight";
// before training, check if optim state is initialized to 0
onnxruntime::training::api::OptimizerCheckpointState& optimizer_states = state.optimizer_checkpoint_state;
+ std::shared_ptr group0_states = optimizer_states.group_named_optimizer_states["group0"];
onnxruntime::training::api::ParameterOptimizerState& param_state =
- optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name);
+ group0_states->param_named_optimizer_states.at(param_name);
OrtValue& moment_1 = param_state.at("momentum0");
std::vector param_vec_before_optimizer_step;
diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc
index 56029b34c24d..cbff1891b8c8 100644
--- a/orttraining/orttraining/training_api/checkpoint.cc
+++ b/orttraining/orttraining/training_api/checkpoint.cc
@@ -449,7 +449,7 @@ Status FromOptimizerState(const OptimizerCheckpointState& optimizer_state,
fbs_optimizer_groups.reserve(optimizer_state.group_named_optimizer_states.size());
for (const auto& group_name : SortedKeys(optimizer_state.group_named_optimizer_states)) {
- const std::shared_ptr& group_optimizer_state_ptr =
+ std::shared_ptr group_optimizer_state_ptr =
optimizer_state.group_named_optimizer_states.at(group_name);
std::vector> optimizer_states;
diff --git a/pyproject.toml b/pyproject.toml
index 1c3a719fb544..6429df2722b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,7 @@ ignore = [
"G004", # FIXME: Enable when the rule can be autofixed
"N803", # Argument casing
"N812", # Allow import torch.nn.functional as F
+ "N813", # Allow importing camelcase names in lowercase
"N999", # Module names
"NPY002", # np.random.Generator may not always fit our use cases
"PERF203", # "try-except-in-loop" only affects Python <3.11, and the improvement is minor; can have false positives