From 281ed8c12d2d2a3f5b683e6267aa0fca4d4add50 Mon Sep 17 00:00:00 2001 From: glen-amd <146770157+glen-amd@users.noreply.github.com> Date: Fri, 12 Jul 2024 21:22:58 -0700 Subject: [PATCH] VitisAI EP Context Model (#20926) # Why so many commits - Runtime debugging - which is necessary - Three different approaches to EP context model - as a result testing back and forth - Windows compatibility issues - this development has been done on Linux for convenience # "Open" (?) questions - Full offloading to a specific EP - Dumping EP context models by EPs vs [by ONNXRT](https://github.com/microsoft/onnxruntime/blob/e2abba18ea9370329ce6894a4eb3e98ad8f11cb6/onnxruntime/core/framework/graph_partitioner.cc#L725) - [Node name to pick nodes](https://github.com/microsoft/onnxruntime/blob/e2abba18ea9370329ce6894a4eb3e98ad8f11cb6/onnxruntime/core/framework/graph_partitioner.cc#L654) # VitisAI EP made three variant implementations that have respective pros and cons (and of course we can combine them) ## Serialize and cache the list of compute capabilities and the original ONNX model itself ## In `ComputeCapability()`, serialize and cache the backend compilation cache and the related necessary cache info such as cache dir and cache key ## In `Compile()`, serialize and cache the backend compilation cache and the related necessary cache info such as cache dir and cache key # EP context model creation - Precondition Session option configuration `kOrtSessionOptionEpContextEnable` (aka "ep.context_enable") is enabled. - Approach 1 - Steps 1. EP creates an ONNX model whose main graph has EP context nodes (i.e., node type is "EPContext"). 2. EP implements/overrides `IExecutionProvider::GetEpContextNodes()` method. 3. ONNXRT core creates an EP context model and saves/dumps it. - `CreateEpContextModel()` in the file "graph_partitioner.cc" - In `get_ep_context_node()`, `Node::Name()` is used to check whether a node is an EP context node. This limits that EP model creation can only happen in `IExecutionProvider::Compile()`. - The workaround is (1) not implementing `IExecutionProvider::GetEpContextNodes()` and (2) dumping the EP context model by EP itself. 4. Optionally, EP can also dump the EP context model it created by iteself. - Examples - `QNNExecutionProvider` - `VitisAIExecutionProvider` - Approach 2 - Steps 1. EP creates an ONNX model whose main graph has EP context nodes (i.e., node type is "EPContext"). 2. EP does NOT implement `IExecutionProvider::GetEpContextNodes()` at all. 3. EP dumps the EP context model it created. - Examples - `TensorrtExecutionProvider` - UPDATES - TRT EP is switching to leveraging `IExecutionProvider::GetEpContextNodes()` - `OpenVINOExecutionProvider` (?) # What to cache in EP context nodes - Non Compilation based EPs - Examples - `VitisAIExecutionProvider` - Characteristics - Heavy lifting work happens in `IExecutionProvider::GetCapability()`. - Preconditions - `IExecutionProvider::GetCapability()` is only called once by ONNXRT. - Cache content - Serialization of a list of `ComputeCapability` - Not EP-specific - Serialized using `onnx::FunctionProto` - EP-specific cache - Compilation based EPs - Examples - `QNNExecutionProvider` - `TensorrtExecutionProvider` - `MIGraphXExecutionProvider` - `OpenVINOExecutionProvider` - Cache content - EP-specific cache # Requirements - Offline / AOT compilation of ONNX models with EP context cache - Compile somewhere, run everywhere - Pseudo code with brief explanation ``` GenerateCache(original_onnx_file, cache_onnx_file) model_buffer = load(original_onnx_file) --> Load the original ONNX model file model_buffer = decrypt(model_buffer) session_options = { kOrtSessionOptionEpContextEnable: true, kOrtSessionOptionEpContextFilePath: temp_file } --> Set necessary configs Ort::CreateSessionFromArray(model_buffer, session_options) --> The new ONNX model with EP context is created and dumped into the user specified file "temp_file" temp_buffer = encrypt(temp_file) write(temp_buffer, cache_onnx_file) --> Write the encypted context of "temp_file" into the "cache_onnx_file" file InitializeInferenceSession(cache_onnx_file) model_buffer = load(cache_onnx_file) --> Load the ONNX model with EP context from the file generated in the previous step model_buffer = decrypt(model_buffer) session_options = { } Ort::CreateSessionFromArray(model_buffer, session_options) --> Create and initalize an session with the EP context model ``` - Python code with comments - EP context model creation ```python import onnxruntime as onnxrt # Session options for creating an ONNX model with EP context cache. sess_opts = onnxrt.SessionOptions() # Verbose. sess_opts.log_severity_level = 0 # This is REQUIRED. sess_opts.add_session_config_entry("ep.context_enable", "1") # This is OPTIONAL. # Either an absolute path (preferred for now) or a relative path (WIP) is okay. # sess_opts.add_session_config_entry("ep.context_file_path", "/some/path/to/original_model_ctx.onnx") # This is OPTIONAL. sess_opts.add_session_config_entry("ep.context_embed_mode", "1") orig_model_location = "/some/path/to/original_model.onnx" sess = onnxrt.InferenceSession(orig_model_location, sess_opts, providers=["VitisAIExecutionProvider"], provider_options=[]) ``` - Inference run with an EP context model ```python import onnxruntime as onnxrt # Session options for creating an ONNX model with EP context cache. sess_opts = onnxrt.SessionOptions() # Default EP context model path. # ep_ctx_model_location = "/some/path/to/origina_model.onnx_ctx.onnx" # User configured EP context model path. ep_ctx_model_location = "/some/path/to/origina_model_ctx.onnx" sess = onnxrt.InferenceSession(ep_ctx_model_location, sess_opts, providers=["VitisAIExecutionProvider"], provider_options=[]) model_inputs = {} run_opts = onnxrt.RunOptions() # Verbose. run_opts.log_severity_level = 1 sess.run(None, model_inputs, run_opts) ``` --------- Co-authored-by: Glen Cao --- include/onnxruntime/core/graph/basic_types.h | 2 + .../providers/shared_library/provider_api.h | 3 + .../shared_library/provider_interfaces.h | 72 ++ .../shared_library/provider_wrappedtypes.h | 87 +++ .../providers/vitisai/imp/ep_context_utils.cc | 682 ++++++++++++++++++ .../core/providers/vitisai/imp/global_api.cc | 25 +- .../vitisai/include/ep_context_utils.h | 81 +++ .../vitisai/include/vaip/global_api.h | 2 + .../vitisai/vitisai_execution_provider.cc | 144 +++- .../vitisai/vitisai_execution_provider.h | 28 +- .../vitisai/vitisai_provider_factory.cc | 0 .../core/session/provider_bridge_ort.cc | 82 ++- .../python/onnxruntime_pybind_state.cc | 3 + 13 files changed, 1195 insertions(+), 16 deletions(-) create mode 100644 onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc create mode 100644 onnxruntime/core/providers/vitisai/include/ep_context_utils.h mode change 100755 => 100644 onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc diff --git a/include/onnxruntime/core/graph/basic_types.h b/include/onnxruntime/core/graph/basic_types.h index 36984d0405bb..cdd5e4c1e571 100644 --- a/include/onnxruntime/core/graph/basic_types.h +++ b/include/onnxruntime/core/graph/basic_types.h @@ -19,6 +19,8 @@ class TensorProto; class SparseTensorProto; class TypeProto; class AttributeProto; +class FunctionProto; +class OperatorSetIdProto; // define types that would come from the ONNX library if we were building against it. #if defined(ORT_MINIMAL_BUILD) using OperatorSetVersion = int; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 86e49627fe26..2f54a04e1530 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -108,6 +108,7 @@ struct NodeProto; struct SparseTensorProto; struct StringStringEntryProto; struct StringStringEntryProtos; // RepeatedPtrField +struct OperatorSetIdProto; struct TensorProto; struct TensorProtos; // RepeatedPtrField struct TensorShapeProto_Dimension; @@ -120,6 +121,7 @@ struct TypeProto_Sequence; struct TypeProto; struct ValueInfoProto; struct ValueInfoProtos; // RepeatedPtrField +struct FunctionProto; struct InferenceContext; class GraphInferencer; using InferenceFunction = std::function; @@ -146,6 +148,7 @@ struct ConfigOptions; struct DataTransferManager; struct IndexedSubGraph; struct IndexedSubGraph_MetaDef; +enum class IndexedSubGraph_SourceOfSchema : uint8_t; struct KernelCreateInfo; struct KernelDef; struct KernelDefBuilder; diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 4d40fcafaeea..382b3ac93252 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -304,6 +304,11 @@ struct ProviderHost { virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0; + // OperatorSetIdProto + virtual std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) = 0; + virtual void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) = 0; + virtual int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) = 0; + #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional virtual const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) = 0; @@ -420,6 +425,11 @@ struct ProviderHost { virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0; virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0; + virtual const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) = 0; + virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) = 0; + virtual int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) = 0; + virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) = 0; + // NodeProto virtual std::unique_ptr NodeProto__construct() = 0; virtual void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) = 0; @@ -427,6 +437,7 @@ struct ProviderHost { virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0; virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0; virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0; + virtual ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) = 0; // TensorProto virtual std::unique_ptr TensorProto__construct() = 0; @@ -495,6 +506,64 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + // FunctionProto + virtual std::unique_ptr FunctionProto__construct() = 0; + virtual void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) = 0; + virtual bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) = 0; + virtual bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) = 0; + virtual std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& name) = 0; + + virtual bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& doc_string) = 0; + + virtual bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& domain) = 0; + + virtual const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; // ConfigOptions @@ -546,6 +615,9 @@ struct ProviderHost { virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) = 0; virtual const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) = 0; + virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0; + virtual IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0; + // KernelDef virtual void KernelDef__operator_delete(KernelDef* p) = 0; virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index fb3b274d9b80..de6c1da1d643 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -80,6 +80,15 @@ struct StringStringEntryProtos final { PROVIDER_DISALLOW_ALL(StringStringEntryProtos) }; + +struct OperatorSetIdProto final { + std::string* mutable_domain() { return g_host->OperatorSetIdProto__mutable_domain(this); } + void set_version(int64_t version) { return g_host->OperatorSetIdProto__set_version(this, version); } + int64_t version() { return g_host->OperatorSetIdProto__version(this); } + + PROVIDER_DISALLOW_ALL(OperatorSetIdProto) +}; + struct AttributeProto final { static std::unique_ptr Create() { return g_host->AttributeProto__construct(); } void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); } @@ -178,6 +187,11 @@ struct ModelProto final { void set_ir_version(int64_t value) { return g_host->ModelProto__set_ir_version(this, value); } + const OperatorSetIdProto& opset_import(int index) const { return g_host->ModelProto__opset_import(this, index); } + OperatorSetIdProto* mutable_opset_import(int index) { return g_host->ModelProto__mutable_opset_import(this, index); } + int opset_import_size() const { return g_host->ModelProto__opset_import_size(this); } + OperatorSetIdProto* add_opset_import() { return g_host->ModelProto__add_opset_import(this); } + ModelProto() = delete; ModelProto(const ModelProto&) = delete; void operator=(const ModelProto&) = delete; @@ -190,6 +204,7 @@ struct NodeProto final { int attribute_size() { return g_host->NodeProto__attribute_size(this); } const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); } AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); } + AttributeProto* add_attribute() { return g_host->NodeProto__add_attribute(this); } NodeProto() = delete; NodeProto(const NodeProto&) = delete; @@ -372,6 +387,69 @@ struct ValueInfoProtos final { PROVIDER_DISALLOW_ALL(ValueInfoProtos) }; + +struct FunctionProto final { + static std::unique_ptr Create() { return g_host->FunctionProto__construct(); } + static void operator delete(void* p) { g_host->FunctionProto__operator_delete(reinterpret_cast(p)); } + + bool SerializeToString(std::string& string) const { return g_host->FunctionProto__SerializeToString(this, string); } + bool SerializeToOstream(std::ostream& output) const { return g_host->FunctionProto__SerializeToOstream(this, output); } + bool ParseFromString(const std::string& data) { return g_host->FunctionProto__ParseFromString(this, data); } + std::string SerializeAsString() const { return g_host->FunctionProto__SerializeAsString(this); } + + bool has_name() const { return g_host->FunctionProto__has_name(this); } + const std::string& name() const { return g_host->FunctionProto__name(this); } + void set_name(const std::string& name) { g_host->FunctionProto__set_name(this, name); } + + bool has_doc_string() const { return g_host->FunctionProto__has_doc_string(this); } + const std::string& doc_string() const { return g_host->FunctionProto__doc_string(this); } + void set_doc_string(const std::string& doc_string) { g_host->FunctionProto__set_doc_string(this, doc_string); } + + bool has_domain() const { return g_host->FunctionProto__has_domain(this); } + const std::string& domain() const { return g_host->FunctionProto__domain(this); } + void set_domain(const std::string& domain) { g_host->FunctionProto__set_domain(this, domain); } + + const std::string& input(int index) const { return g_host->FunctionProto__input(this, index); } + std::string* mutable_input(int index) { return g_host->FunctionProto__mutable_input(this, index); } + int input_size() const { return g_host->FunctionProto__input_size(this); } + void add_input(const std::string& value) { g_host->FunctionProto__add_input(this, value); } + + const std::string& output(int index) const { return g_host->FunctionProto__output(this, index); } + std::string* mutable_output(int index) { return g_host->FunctionProto__mutable_output(this, index); } + int output_size() const { return g_host->FunctionProto__output_size(this); } + void add_output(const std::string& value) { g_host->FunctionProto__add_output(this, value); } + + const std::string& attribute(int index) const { return g_host->FunctionProto__attribute(this, index); } + std::string* mutable_attribute(int index) { return g_host->FunctionProto__mutable_attribute(this, index); } + int attribute_size() const { return g_host->FunctionProto__attribute_size(this); } + void add_attribute(const std::string& value) { g_host->FunctionProto__add_attribute(this, value); } + + const AttributeProto& attribute_proto(int index) const { return g_host->FunctionProto__attribute_proto(this, index); } + AttributeProto* mutable_attribute_proto(int index) { return g_host->FunctionProto__mutable_attribute_proto(this, index); } + int attribute_proto_size() const { return g_host->FunctionProto__attribute_proto_size(this); } + AttributeProto* add_attribute_proto() { return g_host->FunctionProto__add_attribute_proto(this); } + + const NodeProto& node(int index) const { return g_host->FunctionProto__node(this, index); } + NodeProto* mutable_node(int index) { return g_host->FunctionProto__mutable_node(this, index); } + int node_size() const { return g_host->FunctionProto__node_size(this); } + NodeProto* add_node() { return g_host->FunctionProto__add_node(this); } + + const ValueInfoProto& value_info(int index) const { return g_host->FunctionProto__value_info(this, index); } + ValueInfoProtos* mutable_value_info() { return g_host->FunctionProto__mutable_value_info(this); } + ValueInfoProto* mutable_value_info(int index) { return g_host->FunctionProto__mutable_value_info(this, index); } + int value_info_size() const { return g_host->FunctionProto__value_info_size(this); } + ValueInfoProto* add_value_info() { return g_host->FunctionProto__add_value_info(this); } + + const StringStringEntryProto& metadata_props(int index) const { return g_host->FunctionProto__metadata_props(this, index); } + StringStringEntryProtos* mutable_metadata_props() { return g_host->FunctionProto__mutable_metadata_props(this); } + StringStringEntryProto* mutable_metadata_props(int index) { return g_host->FunctionProto__mutable_metadata_props(this, index); } + int metadata_props_size() const { return g_host->FunctionProto__metadata_props_size(this); } + StringStringEntryProto* add_metadata_props() { return g_host->FunctionProto__add_metadata_props(this); } + + FunctionProto() = delete; + FunctionProto(const FunctionProto&) = delete; + void operator=(const FunctionProto&) = delete; +}; } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -449,6 +527,12 @@ struct IndexedSubGraph_MetaDef final { void operator=(const IndexedSubGraph_MetaDef&) = delete; }; +enum class IndexedSubGraph_SourceOfSchema : uint8_t { + CREATE, + REUSE_OR_CREATE, + EXISTING, +}; + struct IndexedSubGraph final { static std::unique_ptr Create() { return g_host->IndexedSubGraph__construct(); } static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast(p)); } @@ -458,6 +542,9 @@ struct IndexedSubGraph final { void SetMetaDef(std::unique_ptr&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast*>(&meta_def_))); } const IndexedSubGraph_MetaDef* GetMetaDef() const { return reinterpret_cast(g_host->IndexedSubGraph__GetMetaDef(this)); } + void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); } + IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); } + IndexedSubGraph() = delete; IndexedSubGraph(const IndexedSubGraph&) = delete; void operator=(const IndexedSubGraph&) = delete; diff --git a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc new file mode 100644 index 000000000000..ab31aa313cf6 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc @@ -0,0 +1,682 @@ +// Standard headers/libs. +#include +#include +#include +#include + +// 3rd-party headers/libs. +#include + +#include "ep_context_utils.h" + +namespace onnxruntime { + +constexpr const char* kVitisAI = "vitisai"; + +std::unique_ptr ConvertIndexedSubGraphToFunctionProto( + const IndexedSubGraph& sub_graph, const Graph& parent_graph) { + auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create(); + auto* p_meta_def = const_cast(sub_graph.GetMetaDef()); + if (p_meta_def) { + p_func_proto->set_name(p_meta_def->name()); + p_func_proto->set_domain(p_meta_def->domain()); + for (const auto& input : p_meta_def->inputs()) { + p_func_proto->add_input(input); + } + auto* p_metadata_props_0 = p_func_proto->add_metadata_props(); + *(p_metadata_props_0->mutable_key()) = "meta_def_inputs_size"; + *(p_metadata_props_0->mutable_value()) = std::to_string(p_meta_def->inputs().size()); + for (const auto& output : p_meta_def->outputs()) { + p_func_proto->add_output(output); + } + // XXX: SerDes with different fields. + for (const auto& initializer : p_meta_def->constant_initializers()) { + p_func_proto->add_input(initializer); + } + // XXX: SerDes with different numbers of fields. + for (const auto& attr_pair : p_meta_def->attributes()) { + p_func_proto->add_attribute(attr_pair.first); + auto* p_attr_proto = p_func_proto->add_attribute_proto(); + *p_attr_proto = attr_pair.second; + } + p_func_proto->set_doc_string(p_meta_def->doc_string()); + // "since_version" + auto* p_metadata_props_1 = p_func_proto->add_metadata_props(); + *(p_metadata_props_1->mutable_key()) = "meta_def_since_version"; + *(p_metadata_props_1->mutable_value()) = std::to_string(p_meta_def->since_version()); + // "status" + auto* p_metadata_props_2 = p_func_proto->add_metadata_props(); + *(p_metadata_props_2->mutable_key()) = "meta_def_status"; + *(p_metadata_props_2->mutable_value()) = + std::to_string(static_cast(p_meta_def->status())); + // TODO: `MetaDef::type_and_shape_inference_function`. + } + auto p_parent_graph_proto = parent_graph.ToGraphProto(); + for (auto node_index : const_cast(sub_graph).Nodes()) { + auto* p_node_proto = p_parent_graph_proto->mutable_node(static_cast(node_index)); + auto* p_attr_proto = p_node_proto->add_attribute(); + p_attr_proto->set_name("parent_graph_node_index"); + p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_proto->set_i(node_index); + *(p_func_proto->add_node()) = *p_node_proto; + } +#if 0 + // Alternative. + for (const auto node_index : sub_graph.Nodes()) { + const auto* p_node = parent_graph.GetNode(node_index); + auto p_node_proto = ONNX_NAMESPACE::NodeProto::Create(); + // XXX + p_node->ToProto(*p_node_proto, true); + auto* p_attr_proto = p_node_proto->add_attribute(); + p_attr_proto->set_name("parent_graph_node_index"); + p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_proto->set_i(node_index); + *(p_func_proto.add_node()) = *p_node_proto; + } +#endif + auto* p_metadata_props_3 = p_func_proto->add_metadata_props(); + *(p_metadata_props_3->mutable_key()) = "schema_source"; + *(p_metadata_props_3->mutable_value()) = + std::to_string(static_cast(sub_graph.GetSchemaSource())); + return p_func_proto; +} + +std::unique_ptr ConvertFunctionProtoToIndexedSubGraph( + const std::unique_ptr& p_func_proto) { + auto p_isg = IndexedSubGraph::Create(); + // "meta_def_inputs_size" (optional) and "schema_source". + int func_metadata_props_size = p_func_proto->metadata_props_size(); + // Precisely, func_metadata_props_size == 4, which implies + // `IndexedSubGraph::meta_def_` is not null and `IndexedSubGraph::nodes` > 1. + if (func_metadata_props_size > 1) { + auto& prop0 = const_cast(p_func_proto->metadata_props(0)); + int isg_meta_def_inputs_size = std::stoi(*(prop0.mutable_value())); + auto p_meta_def = IndexedSubGraph_MetaDef::Create(); + p_meta_def->name() = p_func_proto->name(); + p_meta_def->domain() = p_func_proto->domain(); + auto& prop1 = const_cast(p_func_proto->metadata_props(1)); + p_meta_def->since_version() = std::stoi(*(prop1.mutable_value())); + auto& prop2 = const_cast(p_func_proto->metadata_props(2)); + p_meta_def->status() = static_cast(std::stoi(*(prop2.mutable_value()))); + auto& meta_def_inputs = p_meta_def->inputs(); + for (int i = 0; i < isg_meta_def_inputs_size; i++) { + meta_def_inputs.push_back(p_func_proto->input(i)); + } + auto& meta_def_outputs = p_meta_def->outputs(); + for (int i = 0, l = p_func_proto->output_size(); i < l; i++) { + meta_def_outputs.push_back(p_func_proto->output(i)); + } + auto& meta_def_initializers = p_meta_def->constant_initializers(); + for (int i = isg_meta_def_inputs_size, l = p_func_proto->input_size(); i < l; i++) { + meta_def_initializers.push_back(p_func_proto->input(i)); + } + auto& meta_def_attrs = p_meta_def->attributes(); + for (int i = 0, l = p_func_proto->attribute_size(); i < l; i++) { + meta_def_attrs.emplace(p_func_proto->attribute(i), p_func_proto->attribute_proto(i)); + } + p_meta_def->doc_string() = p_func_proto->doc_string(); + // TODO: `IndexedSubGraph::type_and_shape_inference_function`. + p_isg->SetMetaDef(std::move(p_meta_def)); + } + auto& isg_nodes = p_isg->Nodes(); + for (int i = 0, l = p_func_proto->node_size(); i < l; i++) { + const auto& node_proto = p_func_proto->node(i); + isg_nodes.push_back( + node_proto.attribute(const_cast(node_proto).attribute_size() - 1).i()); + } + auto schema_source = static_cast( + std::stoi(*(const_cast(p_func_proto->metadata_props(func_metadata_props_size - 1)).mutable_value()))); + p_isg->SetSchemaSource(schema_source); + return p_isg; +} + +std::string SerializeCapabilities( + const std::vector>& capability_ptrs, + const Graph& graph) { + std::stringstream ss; + for (const auto& p : capability_ptrs) { + auto& p_subgraph = p->SubGraph(); + auto p_func_proto = ConvertIndexedSubGraphToFunctionProto(*p_subgraph, graph); + std::string func_proto_buf; + p_func_proto->SerializeToString(func_proto_buf); + size_t buf_len = func_proto_buf.length(); + ss.write(reinterpret_cast(&buf_len), sizeof(buf_len)); + ss.write(func_proto_buf.data(), buf_len); + } + if (!ss.good()) { + ORT_THROW("Serialization stream bad"); + } + return ss.str(); +} + +void DeserializeCapabilities(const std::string& ser_capabilities, + std::vector>& capability_ptrs) { + std::istringstream ss(ser_capabilities); + while (!ss.eof()) { + size_t buf_len; + ss.read(reinterpret_cast(&buf_len), sizeof(buf_len)); + std::string buf(buf_len, '\0'); + ss.read(&buf[0], buf_len); + auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create(); + p_func_proto->ParseFromString(buf); + auto p_subgraph = ConvertFunctionProtoToIndexedSubGraph(p_func_proto); + capability_ptrs.push_back(ComputeCapability::Create(std::move(p_subgraph))); + } +} + +std::string SerializeOrigialGraph(const GraphViewer& graph_viewer) { + // XXX: Will Steps 1/2/3 suffice for restoring a model/graph later? + // Any information loss or mismatch? + // Step 1 + const Graph& orig_graph = graph_viewer.GetGraph(); + // Step 2 + const Model& orig_model = orig_graph.GetModel(); + // Step 3 + auto p_orig_model_proto = const_cast(orig_model).ToProto(); + if (p_orig_model_proto->opset_import_size() == 0) { + for (const auto& it : graph_viewer.DomainToVersionMap()) { + auto* p_opset_import = p_orig_model_proto->add_opset_import(); + *(p_opset_import->mutable_domain()) = it.first; + p_opset_import->set_version(it.second); + } + } + + nlohmann::json j_obj; + if (p_orig_model_proto->opset_import_size() > 0) { + for (int i = 0, n = p_orig_model_proto->opset_import_size(); i < n; ++i) { + auto& op_set_id_proto = const_cast(p_orig_model_proto->opset_import(i)); + j_obj[*op_set_id_proto.mutable_domain()] = std::to_string(op_set_id_proto.version()); + } + } + j_obj["orig_graph_name"] = graph_viewer.Name(); + // TODO: platform dependency (Linux vs Windows). + j_obj["orig_model_path"] = graph_viewer.ModelPath().string(); + + // XXX: `ModelProto::SerializeToString` will lose some info, + // e.g., ModelProto.opset_import. + std::string ser_buf; + p_orig_model_proto->SerializeToString(ser_buf); + j_obj["orig_model_proto_ser_str"] = ser_buf; + + return j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace); +} + +// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc". +ONNX_NAMESPACE::ModelProto* CreateEPContexModel( + const GraphViewer& graph_viewer, + const std::string& serialized_ctx_cache, + const std::string& ctx_cache_file_loc, + const int64_t embed_mode, + const std::string& backend_cache_dir, + const std::string& backend_cache_key, + bool saving_orig_graph, + const logging::Logger* p_logger) { + LOGS_DEFAULT(VERBOSE) << "[VitisAI EP]Creating EP context node"; + // Create a new graph/model, reusing the graph name, + // the op-domain-to-opset-version map, + // and the op schema registry of the current graph. + // XXX: This approach (immediately below) has a memory fault issue (std::bad_alloc). + // auto& ep_ctx_graph = graph_viewer.CreateModel(*p_logger)->MainGraph(); + // This apporach (immediately below) has no memory falut issue. + auto p_temp_model = graph_viewer.CreateModel(*p_logger); + auto& ep_ctx_graph = p_temp_model->MainGraph(); + + const auto& graph_inputs = graph_viewer.GetInputs(); + std::vector input_node_arg_ptrs; + input_node_arg_ptrs.reserve(graph_inputs.size()); + // XXX: vs `GraphViewer::GetInputsIncludingInitializers()`. + for (const auto* p_node_arg : graph_inputs) { + auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg( + p_node_arg->Name(), p_node_arg->TypeAsProto()); + input_node_arg_ptrs.push_back(&temp_node_arg); + } + const auto& graph_outputs = graph_viewer.GetOutputs(); + std::vector output_node_arg_ptrs; + output_node_arg_ptrs.reserve(graph_outputs.size()); + for (const auto* p_node_arg : graph_outputs) { + auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto()); + output_node_arg_ptrs.push_back(&temp_node_arg); + } + + // Attr "embed_mode". + auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_0->set_name(kEmbedModeAttr); + // p_attr_0->set_type(onnx::AttributeProto_AttributeType_INT); + p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_0->set_i(embed_mode); + // Attr "ep_cache_context". + auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_1->set_name(kEPCacheContextAttr); + // p_attr_1->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // Relative to the ONNX model file. + p_attr_1->set_s( + embed_mode == 0 ? fs::path(ctx_cache_file_loc).filename().string() : serialized_ctx_cache); + // Attr "source". + auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_2->set_name(kSourceAttr); + // p_attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_2->set_s(kVitisAIExecutionProvider); + // Attr "onnx_model_filename". + auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_3->set_name(kONNXModelFileNameAttr); + // p_attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_3->set_s(graph_viewer.ModelPath().filename().string()); + // Attr "notes". + auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_4->set_name(kNotesAttr); + // p_attr_4->set_type(onnx::AttributeProto_AttributeType_STRING); + p_attr_4->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // FIXME: 2G-limit of ProtoBuf. + if (saving_orig_graph) { + p_attr_4->set_s(SerializeOrigialGraph(graph_viewer)); + } else { + nlohmann::json j_obj; + j_obj["backend_cache_dir"] = backend_cache_dir; + j_obj["backend_cache_key"] = backend_cache_key; + p_attr_4->set_s(j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace)); + } + + auto p_node_attrs = NodeAttributes::Create(); + constexpr int num_attrs = 5; + p_node_attrs->reserve(num_attrs); + p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0); + p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1); + p_node_attrs->emplace(kSourceAttr, *p_attr_2); + p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3); + p_node_attrs->emplace(kNotesAttr, *p_attr_4); + + // Since we don't implement `IExecutionProvider::GetEpContextNodes()` and + // thus don't leverage `CreateEpContextModel()` in the file "graph_partitioner.cc", + // we specify a brand-new node name here. + ep_ctx_graph.AddNode(kEPContextOpName, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain); + + auto res_status = ep_ctx_graph.Resolve(); + ORT_ENFORCE(res_status.IsOK(), res_status.ErrorMessage()); + LOGS_DEFAULT(VERBOSE) << "Created EP context model graph resolved"; + + auto p_ep_ctx_graph_viewer = ep_ctx_graph.CreateGraphViewer(); + auto p_temp_model_2 = p_ep_ctx_graph_viewer->CreateModel(*p_logger); + auto p_ep_ctx_model_proto = p_temp_model_2->ToProto(); + p_ep_ctx_graph_viewer->ToProto(*p_ep_ctx_model_proto->mutable_graph(), true, true); + p_ep_ctx_model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + return p_ep_ctx_model_proto.release(); +} + +// Ref.: `static common::Status Save(Model& model, int fd)` in the file "model.h". +void DumpEPContextModel( + const std::unique_ptr& p_model_proto, const std::string& ep_ctx_model_file_loc) { + std::fstream dump_stream(ep_ctx_model_file_loc, std::ios::out | std::ios::trunc | std::ios::binary); + p_model_proto->SerializeToOstream(dump_stream); + LOGS_DEFAULT(VERBOSE) << "[VitisAI EP] Dumped " << ep_ctx_model_file_loc; +} + +const Node* GetEPContextNodePtr(const Graph& graph) { + // TODO: Support for multi-node EP context model. + for (const auto* p_node : graph.Nodes()) { + if (p_node->OpType() == kEPContextOp) { + return p_node; + } + } + return nullptr; +} + +bool ValidateEPContextNode(const Graph& graph) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(graph); + assert(p_node != nullptr); + auto& attrs = p_node->GetAttributes(); + assert(attrs.count(kEmbedModeAttr) > 0); + assert(attrs.count(kEPCacheContextAttr) > 0); + assert(attrs.count(kSourceAttr) > 0); + const auto& source_val = attrs.at(kSourceAttr).s(); + if (source_val == kVitisAIExecutionProvider) { + return true; + } + size_t vitisai_len = std::strlen(kVitisAI); + assert(source_val.length() == vitisai_len); + for (size_t i = 0; i < vitisai_len; ++i) { + assert(static_cast(std::tolower(source_val[i])) == kVitisAI[i]); + } + return true; +} + +// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc". +void CreateEPContexNodes( + Graph* p_ep_ctx_graph, + const std::vector& fused_nodes_and_graphs, + const std::string& serialized_ctx_cache, + const std::string& ctx_cache_file_loc, + const int64_t embed_mode, + const std::string& backend_cache_dir, + const std::string& backend_cache_key, + bool saving_orig_graph, + const logging::Logger* p_logger) { + LOGS_DEFAULT(VERBOSE) << "[VitisAI EP]Creating EP context nodes"; + int fused_index = 0; + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + Node& fused_node = fused_node_graph.fused_node; + const auto& fused_name = fused_node.Name(); + const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; + // FIXME + const auto& graph_inputs = graph_viewer.GetInputs(); + std::vector input_node_arg_ptrs; + input_node_arg_ptrs.reserve(graph_inputs.size()); + // XXX: vs `GraphViewer::GetInputsIncludingInitializers()`. + for (const auto* p_node_arg : graph_inputs) { + auto& temp_node_arg = p_ep_ctx_graph->GetOrCreateNodeArg( + p_node_arg->Name(), p_node_arg->TypeAsProto()); + input_node_arg_ptrs.push_back(&temp_node_arg); + } + const auto& graph_outputs = graph_viewer.GetOutputs(); + std::vector output_node_arg_ptrs; + output_node_arg_ptrs.reserve(graph_outputs.size()); + for (const auto* p_node_arg : graph_outputs) { + auto& temp_node_arg = p_ep_ctx_graph->GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto()); + output_node_arg_ptrs.push_back(&temp_node_arg); + } + + auto p_node_attrs = NodeAttributes::Create(); + if (fused_index == 0) { + p_node_attrs->reserve(7); + // Attr "ep_cache_context". + auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_1->set_name(kEPCacheContextAttr); + p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // Relative to the ONNX model file. + p_attr_1->set_s( + embed_mode == 0 ? fs::path(ctx_cache_file_loc).filename().string() : serialized_ctx_cache); + p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1); + // Attr "notes". + auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_4->set_name(kNotesAttr); + p_attr_4->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + // FIXME: 2G-limit of ProtoBuf. + if (saving_orig_graph) { + p_attr_4->set_s(SerializeOrigialGraph(graph_viewer)); + } else { + nlohmann::json j_obj; + j_obj["backend_cache_dir"] = backend_cache_dir; + j_obj["backend_cache_key"] = backend_cache_key; + p_attr_4->set_s(j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace)); + } + p_node_attrs->emplace(kNotesAttr, *p_attr_4); + // Attr "main_context". + auto p_attr_5 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_5->set_name(kMainContextAttr); + p_attr_5->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_5->set_i(1); + p_node_attrs->emplace(kMainContextAttr, *p_attr_5); + } else { + p_node_attrs->reserve(5); + // Attr "main_context". + auto p_attr_5 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_5->set_name(kMainContextAttr); + p_attr_5->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_5->set_i(0); + p_node_attrs->emplace(kMainContextAttr, *p_attr_5); + } + // Attr "embed_mode". + auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_0->set_name(kEmbedModeAttr); + p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_0->set_i(embed_mode); + p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0); + // Attr "source". + auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_2->set_name(kSourceAttr); + p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_2->set_s(kVitisAIExecutionProvider); + p_node_attrs->emplace(kSourceAttr, *p_attr_2); + // Attr "onnx_model_filename". + auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_3->set_name(kONNXModelFileNameAttr); + p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_3->set_s(graph_viewer.ModelPath().filename().string()); + p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3); + // Attr "partition_name". + auto p_attr_6 = ONNX_NAMESPACE::AttributeProto::Create(); + p_attr_6->set_name(kPartitionNameAttr); + p_attr_6->set_type(ONNX_NAMESPACE::AttributeProto::STRING); + p_attr_6->set_s(fused_name); + p_node_attrs->emplace(kPartitionNameAttr, *p_attr_6); + + p_ep_ctx_graph->AddNode(fused_name, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain); + + ++fused_index; + } + auto res_status = p_ep_ctx_graph->Resolve(); + ORT_ENFORCE(res_status.IsOK(), res_status.ErrorMessage()); + LOGS_DEFAULT(VERBOSE) << "Created EP context model graph resolved"; +} + +std::string RetrieveEPContextCache( + const Graph& graph, const PathString& ep_ctx_model_loc, bool binary_mode) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(graph); + const auto& attrs = p_node->GetAttributes(); + int64_t embed_mode = attrs.at(kEmbedModeAttr).i(); + const std::string& ep_ctx_cache = attrs.at(kEPCacheContextAttr).s(); + if (embed_mode) { + return ep_ctx_cache; + } + fs::path ep_ctx_fs_path(ep_ctx_model_loc); + // Attr "ep_cache_context" stores a relative path. + ep_ctx_fs_path.replace_filename(fs::path(ep_ctx_cache)); + // TODO: Validaion of the file location to make sure security is met. + if (!fs::exists(ep_ctx_fs_path) || !fs::is_regular_file(ep_ctx_fs_path)) { + ORT_THROW("File for EP context cache is missing"); + } + auto open_mode = binary_mode ? (std::ios::in | std::ios::binary) : std::ios::in; + std::ifstream ifs(ep_ctx_fs_path.string().c_str(), open_mode); + if (!ifs.is_open()) { + ORT_THROW("Exception opening EP context cache file"); + } + ifs.seekg(0, ifs.end); + std::streampos cache_len = ifs.tellg(); + if (cache_len == -1) { + ifs.close(); + ORT_THROW("Error when operating EP context cache file"); + } else if (cache_len == 0) { + ifs.close(); + LOGS_DEFAULT(WARNING) << "Empty EP context cache file: " << ep_ctx_fs_path.string(); + return ""; + } + ifs.seekg(0, ifs.beg); + char* buf = new char[static_cast(cache_len)]; + ifs.read(buf, cache_len); + if (!ifs.good()) { + ifs.close(); + ORT_THROW("Exception reading EP context cache file"); + } + ifs.close(); + std::string cache_payload(buf); + delete[] buf; + return cache_payload; +} + +void RetrieveBackendCacheInfo(const Graph& graph, std::string& cache_dir, std::string& cache_key) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(graph); + if (p_node == nullptr) { + LOGS_DEFAULT(WARNING) << "Failed to retrieve cache info due to no EP context nodes"; + return; + } + const auto& attrs = p_node->GetAttributes(); + const auto& notes_str = attrs.at(kNotesAttr).s(); + nlohmann::json j_obj = nlohmann::json::parse(notes_str); + cache_dir = j_obj["backend_cache_dir"].get(); + cache_key = j_obj["backend_cache_key"].get(); + if (cache_dir.empty()) { + LOGS_DEFAULT(WARNING) << "Retrieved backend cache dir empty"; + } + if (cache_key.empty()) { + LOGS_DEFAULT(WARNING) << "Retrieved backend cache key empty"; + } +} + +std::unique_ptr RetrieveOriginalGraph(const Graph& ep_ctx_graph) { + // TODO: Support for multi-node EP context model. + const auto* p_node = GetEPContextNodePtr(ep_ctx_graph); + const auto& attrs = p_node->GetAttributes(); + const auto& notes_str = attrs.at(kNotesAttr).s(); + nlohmann::json j_obj = nlohmann::json::parse(notes_str); + + const auto& orig_model_path = j_obj["orig_model_path"].get(); + bool model_loaded = false; + auto p_model_proto = ONNX_NAMESPACE::ModelProto::Create(); + if (!orig_model_path.empty() && fs::exists(orig_model_path) && fs::is_regular_file(orig_model_path)) { + auto load_status = Model::Load(ToPathString(orig_model_path), *p_model_proto); + model_loaded = load_status.IsOK(); + } + if (!model_loaded) { + p_model_proto->ParseFromString(j_obj["orig_model_proto_ser_str"].get()); + if (p_model_proto->opset_import_size() == 0) { + for (auto& elem : j_obj.items()) { + if (elem.key() == "orig_model_path" || elem.key() == "orig_graph_name" || elem.key() == "orig_model_proto_ser_str") { + continue; + } + auto* p_op_set_id_proto = p_model_proto->add_opset_import(); + *(p_op_set_id_proto->mutable_domain()) = elem.key(); + p_op_set_id_proto->set_version(std::stoll(elem.value().get())); + } + } + } + auto& logger = logging::LoggingManager::DefaultLogger(); + auto p_model = Model::Create(std::move(*p_model_proto), ToPathString(orig_model_path), nullptr, logger); + auto& graph = p_model->MainGraph(); + graph.ToGraphProto()->set_name(j_obj["orig_graph_name"].get()); + + return graph.CreateGraphViewer(); +} + +bool GraphHasEPContextNode(const Graph& graph) { + size_t vitisai_len = std::strlen(kVitisAI); + for (const auto* p_node : graph.Nodes()) { + if (p_node->OpType() != kEPContextOp) { + continue; + } + const auto& attrs = p_node->GetAttributes(); + if (attrs.count(kSourceAttr) == 0) { + continue; + } + const auto& source_val = attrs.at(kSourceAttr).s(); + if (source_val == kVitisAIExecutionProvider) { + return true; + } + if (source_val.length() != vitisai_len) { + continue; + } + size_t j = 0; + do { + if (static_cast(std::tolower(source_val[j])) != kVitisAI[j]) { + break; + } + ++j; + } while (j < vitisai_len); + if (j == vitisai_len) { + return true; + } + } + return false; +} + +bool FusedGraphHasEPContextNode( + const std::vector& fused_nodes_and_graphs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + bool has_node = GraphHasEPContextNode(fused_node_graph.filtered_graph.get().GetGraph()); + if (has_node) { + return true; + } + } + return false; +} + +const fs::path& GetTopLevelModelPath(const GraphViewer& graph_viewer) { + const auto& graph = graph_viewer.GetGraph(); + const Graph* p_graph = &graph; + while (p_graph->IsSubgraph()) { + p_graph = p_graph->ParentGraph(); + } + return p_graph->ModelPath(); +} + +bool GetEPContextModelFileLocation( + const std::string& ep_ctx_model_path_cfg, + const PathString& model_path_str, + bool is_ep_ctx_model, + PathString& ep_ctx_model_file_loc) { + if (!ep_ctx_model_file_loc.empty()) { + return true; + } + if (!ep_ctx_model_path_cfg.empty()) { + ep_ctx_model_file_loc = ToPathString(ep_ctx_model_path_cfg); + } else if (!model_path_str.empty()) { + if (is_ep_ctx_model) { + ep_ctx_model_file_loc = model_path_str; + } else { + // Two alternatives for this case. + // Alternative 1: + // 1) Implement/override the method `IExecutionProvider::GetEpContextNodes()`. + // 2) And follow how the default path is implemented in `CreateEpContextModel()` + // in the file "graph_partitioner.cc". + // 3) Model dump is not required. + // Alternative 2: + // 1) Do NOT implement/override `IExecutionProvider::GetEpContextNodes()`. + // 2) No need to follow `CreateEpContextModel()` in the file "graph_partitioner.cc", + // freely implement what the default path is like. + // 3) Model dump is required. +#if 0 + ep_ctx_model_file_loc = model_path_str + ToPathString("_ctx.onnx"); +#endif +#if 1 + fs::path model_fs_path(model_path_str); + fs::path ep_ctx_model_fs_path(model_fs_path.parent_path() / model_fs_path.stem()); + ep_ctx_model_fs_path += fs::path("_ctx.onnx"); + ep_ctx_model_file_loc = ToPathString(ep_ctx_model_fs_path.string()); +#endif + } + } + return !ep_ctx_model_file_loc.empty(); +} + +// The file for EP context cache is in the same folder as the EP context model file. +PathString GetEPContextCacheFileLocation( + const PathString& ep_ctx_model_file_loc, const PathString& model_path_str) { + if (!ep_ctx_model_file_loc.empty()) { + fs::path ep_ctx_model_fs_path(ep_ctx_model_file_loc); + fs::path ep_ctx_cache_fs_path(ep_ctx_model_fs_path.parent_path() / ep_ctx_model_fs_path.stem()); + ep_ctx_cache_fs_path += fs::path("__ep_ctx_cache.bin"); + return ToPathString(ep_ctx_cache_fs_path.string()); + } + fs::path model_fs_path(model_path_str); + fs::path ep_ctx_cache_fs_path(model_fs_path.parent_path() / model_fs_path.stem()); + ep_ctx_cache_fs_path += fs::path("__ep_ctx_cache.bin"); + return ToPathString(ep_ctx_cache_fs_path.string()); +} + +std::string Slurp(const fs::path& file_location, bool binary_mode) { + // std::filesystem::value_type == onnxruntime::PathChar == ORTCHAR_T + // std::filesystem::string_type == onnxruntime::PathString + // const char* location_str = PathToUTF8String(file_location.native()).c_str(); + std::ifstream ifs; + ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + std::stringstream ss; + try { + auto open_mode = binary_mode ? (std::ios::in | std::ios::binary) : std::ios::in; + ifs.open(file_location.string().c_str(), open_mode); + ss << ifs.rdbuf(); + if (!ss.good()) { + LOGS_DEFAULT(WARNING) << "Failed to write to stream"; + } + ifs.close(); + } catch (std::system_error& se) { + LOGS_DEFAULT(WARNING) << "Failed to read " << file_location << ": " << se.code().message(); + } + return ss.str(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index e9ae93ded40c..1a3cc5979ff5 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -53,6 +53,8 @@ struct OrtVitisAIEpAPI { std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); uint32_t (*vaip_get_version)(); + void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); + void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); void Ensure() { if (handle_) return; @@ -77,6 +79,8 @@ struct OrtVitisAIEpAPI { } std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version", (void**)&vaip_get_version); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache)); } private: @@ -122,13 +126,7 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config vaip_core::DllSafe>> compile_onnx_model( const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { -#ifndef _WIN32 - auto model_path = graph_viewer.ModelPath().string(); -#else - using convert_t = std::codecvt_utf8; - std::wstring_convert strconverter; - auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().string()); -#endif + auto model_path = PathToUTF8String(ToPathString(graph_viewer.ModelPath().string())); if (s_library_vitisaiep.compile_onnx_model_with_options) { return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); } else { @@ -137,6 +135,17 @@ vaip_core::DllSafe>> c } } +void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data) { + const std::string& model_path = PathToUTF8String(model_path_str); + const onnxruntime::Graph& graph = graph_viewer.GetGraph(); + const auto json_str = config_to_json_str(options); + s_library_vitisaiep.get_backend_compilation_cache(model_path, graph, json_str.c_str(), compiler_codes, cache_dir, cache_key, cache_data); +} + +void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path) { + s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path); +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = @@ -218,7 +227,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { auto& logger = logging::LoggingManager::DefaultLogger(); auto& model = const_cast(const_model); auto model_proto = model.ToProto(); - auto file_path = model.MainGraph().ModelPath().string(); + auto file_path = ToPathString(model.MainGraph().ModelPath().string()); auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()}; auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger); auto status = ret->MainGraph().Resolve(); diff --git a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h new file mode 100644 index 000000000000..61a595cf1ae1 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h @@ -0,0 +1,81 @@ +#pragma once + +// Standard headers/libs. +#include +#include +#include +#include + +// 1st-party headers/libs. +#include "core/providers/shared_library/provider_api.h" + +namespace fs = std::filesystem; + +namespace onnxruntime { + +constexpr const uint8_t kXCCode = 1; +constexpr const uint8_t kDDCode = 2; +constexpr const uint8_t kVCode = 4; + +static constexpr const char* kEPContextOp = "EPContext"; +static constexpr const char* kMainContextAttr = "main_context"; +static constexpr const char* kEPCacheContextAttr = "ep_cache_context"; +static constexpr const char* kEmbedModeAttr = "embed_mode"; +static constexpr const char* kPartitionNameAttr = "partition_name"; +static constexpr const char* kSourceAttr = "source"; +static constexpr const char* kEPSDKVersionAttr = "ep_sdk_version"; +static constexpr const char* kONNXModelFileNameAttr = "onnx_model_filename"; +static constexpr const char* kNotesAttr = "notes"; +static constexpr const char* kEPContextOpDomain = "com.microsoft"; +static constexpr const char* kEPContextOpName = "VitisAIEPContextOp"; + +std::unique_ptr +ConvertIndexedSubGraphToFunctionProto(const IndexedSubGraph&, const Graph&); + +std::unique_ptr ConvertFunctionProtoToIndexedSubGraph( + const std::unique_ptr&); + +std::string SerializeCapabilities( + const std::vector>&, const Graph&); + +void DeserializeCapabilities( + const std::string&, std::vector>&); + +std::string SerializeOrigialGraph(const GraphViewer&); + +// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc". +ONNX_NAMESPACE::ModelProto* CreateEPContexModel(const GraphViewer&, const std::string&, const std::string&, const int64_t, + const std::string&, const std::string&, bool, const logging::Logger*); + +// Ref.: `static common::Status Save(Model& model, int fd)` in the file "model.h". +void DumpEPContextModel(const std::unique_ptr&, const std::string&); + +const Node* GetEPContextNodePtr(const Graph&); + +bool ValidateEPContextNode(const Graph&); + +void CreateEPContexNodes(Graph*, const std::vector&, const std::string&, const std::string&, + const int64_t, const std::string&, const std::string&, bool, const logging::Logger*); + +std::string RetrieveEPContextCache(const Graph&, const PathString&, bool binary_mode = true); + +void RetrieveBackendCacheInfo(const Graph&, std::string&, std::string&); + +std::unique_ptr RetrieveOriginalGraph(const Graph&); + +bool GraphHasEPContextNode(const Graph&); + +bool FusedGraphHasEPContextNode( + const std::vector&); + +const fs::path& GetTopLevelModelPath(const GraphViewer&); + +bool GetEPContextModelFileLocation( + const std::string&, const PathString&, bool, PathString&); + +// The file for EP context cache is in the same folder as the EP context model file. +PathString GetEPContextCacheFileLocation(const PathString&, const PathString&); + +std::string Slurp(const fs::path&, bool binary_mode = false); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 1f8b8802e86b..3fdbc60bb0ee 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -14,3 +14,5 @@ void initialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); +void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); +void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 6fc09f3495aa..f45b89649bfc 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -2,22 +2,43 @@ // Licensed under the MIT License. #include "vitisai_execution_provider.h" +// Standard headers/libs. #include #include #include +#include + +// 1st-party headers/libs. +#include "core/platform/env_var_utils.h" +#include "core/common/exceptions.h" #include "vaip/capability.h" #include "vaip/global_api.h" +#include "ep_context_utils.h" using namespace ONNX_NAMESPACE; +namespace fs = std::filesystem; + namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; VitisAIExecutionProvider::VitisAIExecutionProvider( const ProviderOptions& info) + // const ProviderOptions& info, const SessionOptions* p_sess_opts) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { CreateKernelRegistry(); + + auto it = info_.find("ep_context_enable"); + ep_ctx_enabled_ = it != info_.end() && it->second == "1"; + it = info_.find("ep_context_embed_mode"); + ep_ctx_embed_mode_ = it != info_.end() && it->second != "0"; + // ep_ctx_embed_mode_ = it == info_.end() || it->second != "0"; + it = info_.find("ep_context_file_path"); + ep_ctx_model_path_cfg_ = it == info_.end() ? "" : it->second; + LOGS_DEFAULT(VERBOSE) << "EP Context cache enabled: " << ep_ctx_enabled_; + LOGS_DEFAULT(VERBOSE) << "EP context cache embed mode: " << ep_ctx_embed_mode_; + LOGS_DEFAULT(VERBOSE) << "User specified EP context cache path: " << ep_ctx_model_path_cfg_; } void VitisAIExecutionProvider::CreateKernelRegistry() { @@ -30,9 +51,115 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } +// This method is called after both `GetComputeCapabilityOps()` and `Compile()`. +// This timing is required to work with both compilation-based EPs and non-compilation-based EPs. +const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_node_ptrs; + // All preconditions are supposed to have happened. + if (p_ep_ctx_model_) { + auto& graph = p_ep_ctx_model_->MainGraph(); + for (const auto* p_node : graph.Nodes()) { + ep_context_node_ptrs.push_back(p_node); + } + } + return ep_context_node_ptrs; +} + +void VitisAIExecutionProvider::LoadEPContexModelFromFile() const { + // XXX: should "p_ep_ctx_model_" be checked or not? + if (!p_ep_ctx_model_ && !ep_ctx_model_file_loc_.empty()) { + auto status = Model::Load(ep_ctx_model_file_loc_, *p_ep_ctx_model_proto_); + if (!status.IsOK()) { + ORT_THROW("Loading EP context model failed from ", PathToUTF8String(ep_ctx_model_file_loc_)); + } + p_ep_ctx_model_ = Model::Create(std::move(*p_ep_ctx_model_proto_), ep_ctx_model_file_loc_, nullptr, *GetLogger()); + LOGS_DEFAULT(VERBOSE) << "Loaded EP context model from: " << PathToUTF8String(ep_ctx_model_file_loc_); + } else if (ep_ctx_model_file_loc_.empty()) { + LOGS_DEFAULT(WARNING) << "Cannot load an EP-context model due to bad file path"; + } +} + +void VitisAIExecutionProvider::PrepareEPContextEnablement( + const onnxruntime::GraphViewer& graph_viewer) const { + if (model_path_str_.empty()) { + // TODO: platform dependency (Linux vs Windows). + model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string()); + } + std::string backend_cache_dir, backend_cache_key; + get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode, backend_cache_dir, backend_cache_key, backend_cache_data_); + info_["cacheDir"] = backend_cache_dir; + info_["cacheKey"] = backend_cache_key; + // Create a new model, reusing the graph name, the op-domain-to-opset-version map, + // the op schema registry of the current graph, etc. + p_ep_ctx_model_ = graph_viewer.CreateModel(*GetLogger()); + LOGS_DEFAULT(VERBOSE) << "Container model created"; +} + +void VitisAIExecutionProvider::FulfillEPContextEnablement( + const std::vector& fused_nodes_and_graphs) { + auto& ep_ctx_graph = p_ep_ctx_model_->MainGraph(); + if (!ep_ctx_embed_mode_) { + auto ep_ctx_cache_path_str = GetEPContextCacheFileLocation(ep_ctx_model_file_loc_, model_path_str_); + std::ofstream ep_ctx_cache_ofs(ep_ctx_cache_path_str.c_str(), std::ios::trunc); + if (!ep_ctx_cache_ofs.is_open()) { + ORT_THROW("Failed to open a file to write EP context cache: ", ep_ctx_cache_path_str.c_str()); + } + ep_ctx_cache_ofs.write(backend_cache_data_.c_str(), backend_cache_data_.length()); + if (!ep_ctx_cache_ofs.good()) { + ep_ctx_cache_ofs.close(); + ORT_THROW("Exception writing EP context cache file: ", ep_ctx_cache_path_str.c_str()); + } + ep_ctx_cache_ofs.close(); + CreateEPContexNodes(&ep_ctx_graph, fused_nodes_and_graphs, "", PathToUTF8String(ep_ctx_cache_path_str), 0, info_.at("cacheDir"), info_.at("cacheKey"), false, GetLogger()); + } else { + CreateEPContexNodes(&ep_ctx_graph, fused_nodes_and_graphs, backend_cache_data_, "", 1, info_["cacheDir"], info_["cacheKey"], false, GetLogger()); + } + if (GraphHasEPContextNode(ep_ctx_graph)) { + LOGS_DEFAULT(VERBOSE) << "Created model has EP context nodes"; + } else { + LOGS_DEFAULT(WARNING) << "No EP eontext nodes created"; + } +} + std::vector> VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { - if (graph.IsSubgraph()) { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + bool is_ep_ctx_model = GraphHasEPContextNode(graph_viewer.GetGraph()); + // TODO: platform dependency (Linux vs Windows). + model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string()); + if (GetEPContextModelFileLocation( + ep_ctx_model_path_cfg_, model_path_str_, is_ep_ctx_model, ep_ctx_model_file_loc_)) { + if (is_ep_ctx_model) { + LOGS_DEFAULT(VERBOSE) << "An EP context model passed in"; + ValidateEPContextNode(graph_viewer.GetGraph()); + std::string cache_dir, cache_key; + RetrieveBackendCacheInfo(graph_viewer.GetGraph(), cache_dir, cache_key); + info_["cacheDir"] = cache_dir; + info_["cacheKey"] = cache_key; + LOGS_DEFAULT(VERBOSE) << "Trying getting compilation cache from " << PathToUTF8String(ep_ctx_model_file_loc_); + auto ep_ctx_payload = RetrieveEPContextCache(graph_viewer.GetGraph(), ep_ctx_model_file_loc_, false); + restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string()); + } else { + if (fs::exists(ep_ctx_model_file_loc_) && fs::is_regular_file(ep_ctx_model_file_loc_) && ep_ctx_enabled_) { + ORT_THROW("The inference session was created with a normal ONNX model but a model file with EP context cache exists at ", + PathToUTF8String(ep_ctx_model_file_loc_), ". Please remove the EP context model manually if you want to re-generate it."); + // Disable the flexibility implemented below by throwing an exception. + // Now the code below is unreachable but DCE will take care of it. + // We might want to re-enable it in future, so we keep it as is. + LoadEPContexModelFromFile(); + ValidateEPContextNode(p_ep_ctx_model_->MainGraph()); + std::string cache_dir, cache_key; + RetrieveBackendCacheInfo(p_ep_ctx_model_->MainGraph(), cache_dir, cache_key); + info_["cacheDir"] = cache_dir; + info_["cacheKey"] = cache_key; + auto ep_ctx_payload = RetrieveEPContextCache(p_ep_ctx_model_->MainGraph(), ep_ctx_model_file_loc_, false); + restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string()); + } + } + } else { + LOGS_DEFAULT(WARNING) << "Failed to get EP context model file location"; + } + + if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; } @@ -40,13 +167,16 @@ std::vector> VitisAIExecutionProvider::GetCap // Only compiling a model once is currently supported return {}; } - execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); - auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); + execution_providers_ = std::make_unique(compile_onnx_model(graph_viewer, *GetLogger(), info_)); + auto result = vaip::GetComputeCapabilityOps(graph_viewer, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { - result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph, ep.get(), index)); + result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph_viewer, ep.get(), index)); index = index + 1; } + if (ep_ctx_enabled_ && !is_ep_ctx_model) { + PrepareEPContextEnablement(graph_viewer); + } return result; } @@ -74,6 +204,10 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector #include #include #include #include +// 1st-party headers/libs. +// #include "core/framework/session_options.h" #include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" +#include "core/common/inlined_containers_fwd.h" // we cannot include vaip/vaip.hpp here because header file referred by // onnxruntime_pybind_state_common.cc @@ -24,9 +28,11 @@ namespace onnxruntime { class VitisAIExecutionProvider : public IExecutionProvider { public: explicit VitisAIExecutionProvider(const ProviderOptions& info); + // explicit VitisAIExecutionProvider(const ProviderOptions& info, + // const SessionOptions* p_sess_opts = nullptr); ~VitisAIExecutionProvider() = default; - std::vector> GetCapability(const onnxruntime::GraphViewer& graph, + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } @@ -35,16 +41,34 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; + // This method is called after both `GetComputeCapabilityOps()` and `Compile()`. + // This timing is required to work with both compliation-based EPs and non-compilation-based EPs. + const InlinedVector GetEpContextNodes() const override; + private: void CreateKernelRegistry(); using my_ep_t = vaip_core::DllSafe>>; using my_ep_uptr_t = std::shared_ptr; // we have to hide the implementation by forward declaration. mutable my_ep_uptr_t execution_providers_; - ProviderOptions info_; + mutable ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; std::set vitisai_optypes_; + // EP context related. + bool ep_ctx_enabled_ = false; + bool ep_ctx_embed_mode_ = true; + std::string ep_ctx_model_path_cfg_{""}; + mutable std::string backend_cache_data_{""}; + mutable PathString model_path_str_{}; + mutable PathString ep_ctx_model_file_loc_{}; + mutable std::unique_ptr p_ep_ctx_model_; + mutable std::unique_ptr p_ep_ctx_model_proto_; + // It might need to be called before loading + // the EP context model that is compiled AOT/offline. + void LoadEPContexModelFromFile() const; + void PrepareEPContextEnablement(const onnxruntime::GraphViewer&) const; + void FulfillEPContextEnablement(const std::vector&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index bd5a68152fb7..4f9669a7dcc4 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -28,6 +28,7 @@ #include "core/session/inference_session.h" #include "core/session/abi_session_options_impl.h" #include "core/session/ort_apis.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/provider_bridge_ort.h" #include "core/util/math.h" #include "core/framework/sparse_utils.h" @@ -68,10 +69,12 @@ using StringStringEntryProtos = google::protobuf::RepeatedPtrField; using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; using ValueInfoProtos = google::protobuf::RepeatedPtrField; +using FunctionProtos = google::protobuf::RepeatedPtrField; } // namespace ONNX_NAMESPACE namespace onnxruntime { using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; +using IndexedSubGraph_SourceOfSchema = IndexedSubGraph::SourceOfSchema; } // namespace onnxruntime #include "core/common/cpuid_info.h" @@ -400,6 +403,11 @@ struct ProviderHostImpl : ProviderHost { int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); } ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); }; + // OperatorSetIdProto + std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) override { return p->mutable_domain(); } + void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) override { return p->set_version(version); } + int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) override { return p->version(); } + #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional (wrapped) const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) override { return p->elem_type(); } @@ -528,6 +536,11 @@ struct ProviderHostImpl : ProviderHost { void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); } ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); }; + const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) override { return p->opset_import(index); } + ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) override { return p->mutable_opset_import(index); } + int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) override { return p->opset_import_size(); } + ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) override { return p->add_opset_import(); } + // NodeProto (wrapped) std::unique_ptr NodeProto__construct() override { return std::make_unique(); } void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) override { delete p; } @@ -535,6 +548,7 @@ struct ProviderHostImpl : ProviderHost { int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) override { return p->attribute_size(); } const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const override { return p->attribute(index); } ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) override { return p->mutable_attribute(index); } + ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) override { return p->add_attribute(); } // TensorProto (wrapped) std::unique_ptr TensorProto__construct() override { return std::make_unique(); } @@ -609,6 +623,64 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } + // FunctionProto (wrapped) + std::unique_ptr FunctionProto__construct() override { return std::make_unique(); } + void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) override { delete p; } + + bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) override { return p->SerializeToString(&string); } + bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) override { return p->SerializeToOstream(&output); } + bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) override { return p->ParseFromString(data); } + std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) override { return p->SerializeAsString(); } + + bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_name(); } + const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->name(); } + void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const std::string& name) override { p->set_name(name); } + + bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_doc_string(); } + const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->doc_string(); } + void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const std::string& doc_string) override { p->set_doc_string(doc_string); } + + bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_domain(); } + const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->domain(); } + void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const std::string& domain) override { p->set_domain(domain); } + + const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->input(index); } + std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_input(index); } + int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->input_size(); } + void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_input(value); } + + const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->output(index); } + std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_output(index); } + int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->output_size(); } + void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_output(value); } + + const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute(index); } + std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute(index); } + int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_size(); } + void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_attribute(value); } + + const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute_proto(index); } + ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute_proto(index); } + int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_proto_size(); } + ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_attribute_proto(); } + + const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->node(index); } + ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_node(index); } + int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->node_size(); } + ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_node(); } + + const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->value_info(index); } + ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_value_info(index); } + ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_value_info(); } + int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->value_info_size(); } + ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_value_info(); } + + const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->metadata_props(index); } + ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_metadata_props(index); } + ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_metadata_props(); } + int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); } + ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); } + static int32_t convert_elem_type(const ONNX_NAMESPACE::AttributeProto* data_type) { int32_t elemType = 0; if (data_type->s() == "float32") { @@ -791,9 +863,12 @@ struct ProviderHostImpl : ProviderHost { std::vector& IndexedSubGraph__Nodes(IndexedSubGraph* p) override { return p->nodes; } - void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { return p->SetMetaDef(std::move(meta_def_)); } + void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { p->SetMetaDef(std::move(meta_def_)); } const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) override { return p->GetMetaDef(); } + void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) override { p->schema_source = schema_source; } + IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) override { return p->schema_source; } + // KernelDef (wrapped) void KernelDef__operator_delete(KernelDef* p) override { delete p; } void KernelDef__SinceVersion(const KernelDef* p, int* start, int* end) override { return p->SinceVersion(start, end); } @@ -2842,6 +2917,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ provider_options[provider_options_keys[i]] = provider_options_values[i]; } + // EP context related session config options. + provider_options["ep_context_enable"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + provider_options["ep_context_embed_mode"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + provider_options["ep_context_file_path"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e539614fd6d1..e13285c60e69 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1114,6 +1114,9 @@ std::unique_ptr CreateExecutionProviderInstance( if (it != provider_options_map.end()) { info = it->second; } + info["ep_context_enable"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + info["ep_context_embed_mode"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + info["ep_context_file_path"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kAclExecutionProvider) {