From 2348dc90beed7eacbdb458dcf682d879decb132b Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 24 Sep 2024 16:31:52 -0700 Subject: [PATCH] Optimize Decoder Pipeline Model Execution (#907) --- src/config.cpp | 2 + src/config.h | 1 + src/generators.cpp | 7 +++- src/models/decoder_only_pipeline.cpp | 60 +++++++++++++++++++++++----- src/models/decoder_only_pipeline.h | 2 +- src/models/model.cpp | 5 +++ 6 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 67889f618..1e07e7f9f 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -96,6 +96,8 @@ struct SessionOptions_Element : JSON::Element { v_.enable_quant_qdq_cleanup = value; else if (name == "ep_context_enable") v_.ep_context_enable = value; + else if (name == "use_env_allocators") + v_.use_env_allocators = value; else throw JSON::unknown_value_error{}; } diff --git a/src/config.h b/src/config.h index 1a928576a..77e84fafc 100644 --- a/src/config.h +++ b/src/config.h @@ -40,6 +40,7 @@ struct Config { std::optional log_id; std::optional log_severity_level; std::optional enable_profiling; + bool use_env_allocators{true}; std::vector provider_options; }; diff --git a/src/generators.cpp b/src/generators.cpp index f0cb710c7..46443a660 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -13,7 +13,12 @@ namespace Generators { static bool _ = (Ort::InitApi(), false); -OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {} +OrtGlobals::OrtGlobals() + : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} { + auto arena_config = OrtArenaCfg::Create(0, -1, -1, -1); + Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()}; + env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config); +} // Ensure Shutdown() has been called before process exit struct ValidateShutdown { diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 4e936e7d4..dbe90a2c0 100644 --- 
a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -119,7 +119,13 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr continue; } - // Clear the intermediate pipeline state from previous runs. + // Clear the intermediate pipeline state outputs from the previous runs. + // These outputs will be replaced by the outputs from the current run. + for (const auto& output_name : pipeline_state->output_names_) { + if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) { + ortvalue_store_.erase(iter); + } + } pipeline_state->ClearIO(); // Managed inputs and outputs are those inputs and outputs that the @@ -143,10 +149,10 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr } // Add outputs from the previous pipeline states to the current pipeline state - for (auto& [name, ortvalue] : ortvalue_pool_) { + for (auto& [name, ortvalue] : ortvalue_store_) { if (pipeline_state->HasInput(name)) { pipeline_state->input_names_.push_back(name.c_str()); - pipeline_state->inputs_.push_back(ortvalue); + pipeline_state->inputs_.push_back(ortvalue.get()); } } @@ -167,6 +173,25 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr } } + // Output of pipeline models could also be managed inputs. + // For example, the output of a pipeline model could be the key-value cache. + // In such cases, use the managed output buffers and register them with the pipeline model as outputs. + for (const auto& input_name : input_names_) { + if (pipeline_state->HasOutput(input_name)) { + if (!pipeline_state->SupportsPrimaryDevice()) { + std::ostringstream oss; + oss << "Managed input " << input_name << " resides on the primary device type (" + << to_string(model_.device_type_) << "). 
" << "But the pipeline model " + << model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id + << " is expecting it to reside elsewhere."; + throw std::runtime_error(oss.str()); + } + pipeline_state->output_names_.push_back(input_name); + pipeline_state->outputs_.push_back(State::GetInput(input_name)); + } + } + // Add all the remaining outputs for the intermediate pipeline state for (const auto& output_name : model_.config_->model.decoder.pipeline[pipeline_state->id_].outputs) { if (std::none_of(pipeline_state->output_names_.begin(), pipeline_state->output_names_.end(), @@ -179,16 +204,31 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr // Run the intermediate pipeline state pipeline_state->Run(current_length, next_tokens, next_indices); - // Store the non-managed outputs from the current pipeline state in the ortvalue pool. + // Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store. // All non managed outputs are assumed to be on CPU for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) { if (std::none_of(output_names_.begin(), output_names_.end(), + [&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; }) && + std::none_of(input_names_.begin(), input_names_.end(), [&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; })) { auto forwarded_output = model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.find(pipeline_state->output_names_[i]); if (forwarded_output != model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.end()) { - ortvalue_pool_[forwarded_output->second] = pipeline_state->outputs_[i]; + ortvalue_store_[forwarded_output->second] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]); } else { - ortvalue_pool_[pipeline_state->output_names_[i]] = pipeline_state->outputs_[i]; + ortvalue_store_[pipeline_state->output_names_[i]] = 
std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]); + } + } + } + } + + // Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier. + if (!first_run_) { + for (auto& pipeline_state : pipeline_states_) { + if (!model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_token_gen) { + for (const auto& output_name : pipeline_state->output_names_) { + if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) { + ortvalue_store_.erase(iter); + } } } } @@ -208,10 +248,10 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray& } OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) { - // Check the ortvalue pool to search if name is one of the non-managed output. - auto it = ortvalue_pool_.find(name); - if (it != ortvalue_pool_.end()) { - return it->second; + // Check the ortvalue store to search if name is one of the non-managed output. + auto it = ortvalue_store_.find(name); + if (it != ortvalue_store_.end()) { + return it->second.get(); } // Search managed outputs saved in this State. 
diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index b34173cef..948291612 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -66,7 +66,7 @@ struct DecoderOnlyPipelineState : State { std::vector> pipeline_states_; // Stores all the outputs from the previous pipeline state(s) - std::unordered_map<std::string, OrtValue*> ortvalue_pool_; + std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_; InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; diff --git a/src/models/model.cpp b/src/models/model.cpp index a83d7c2a6..1fa9a5671 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -364,6 +364,11 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ session_options.SetEpContextFilePath(config_session_options.ep_context_file_path.value().c_str()); } + if (config_session_options.provider_options.empty() && config_session_options.use_env_allocators) { + // Share env allocators across sessions that only use the CPU provider + session_options.AddConfigEntry("session.use_env_allocators", "1"); + } + for (auto& provider_options : config_session_options.provider_options) { if (provider_options.name == "cuda") { auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();