Skip to content

Commit

Permalink
Optimize Decoder Pipeline Model Execution (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Sep 24, 2024
1 parent 6b5194e commit 2348dc9
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ struct SessionOptions_Element : JSON::Element {
v_.enable_quant_qdq_cleanup = value;
else if (name == "ep_context_enable")
v_.ep_context_enable = value;
else if (name == "use_env_allocators")
v_.use_env_allocators = value;
else
throw JSON::unknown_value_error{};
}
Expand Down
1 change: 1 addition & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct Config {
std::optional<std::string> log_id;
std::optional<int> log_severity_level;
std::optional<std::string> enable_profiling;
bool use_env_allocators{true};

std::vector<ProviderOptions> provider_options;
};
Expand Down
7 changes: 6 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ namespace Generators {

static bool _ = (Ort::InitApi(), false);

OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {}
// Construct the process-wide ORT globals: create the environment, then
// register a CPU arena allocator with it so that sessions opting into
// "session.use_env_allocators" can share a single allocator instance.
OrtGlobals::OrtGlobals()
    : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {
  // Default-options CPU allocator supplies the memory info for registration.
  Ort::Allocator& default_cpu_allocator = Ort::Allocator::GetWithDefaultOptions();
  // Arena settings (0, -1, -1, -1): default arena behavior for every tunable.
  auto cpu_arena_cfg = OrtArenaCfg::Create(0, -1, -1, -1);
  env_->CreateAndRegisterAllocator(default_cpu_allocator.GetInfo(), *cpu_arena_cfg);
}

// Ensure Shutdown() has been called before process exit
struct ValidateShutdown {
Expand Down
60 changes: 50 additions & 10 deletions src/models/decoder_only_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,13 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
continue;
}

// Clear the intermediate pipeline state from previous runs.
// Clear the intermediate pipeline state outputs from the previous runs.
// These outputs will be replaced by the outputs from the current run.
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
}
}
pipeline_state->ClearIO();

// Managed inputs and outputs are those inputs and outputs that the
Expand All @@ -143,10 +149,10 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}

// Add outputs from the previous pipeline states to the current pipeline state
for (auto& [name, ortvalue] : ortvalue_pool_) {
for (auto& [name, ortvalue] : ortvalue_store_) {
if (pipeline_state->HasInput(name)) {
pipeline_state->input_names_.push_back(name.c_str());
pipeline_state->inputs_.push_back(ortvalue);
pipeline_state->inputs_.push_back(ortvalue.get());
}
}

Expand All @@ -167,6 +173,25 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}
}

// Output of pipeline models could also be managed inputs.
// For example, the output of a pipeline model could be the key-value cache.
// In such cases, use the managed output buffers and register them with the pipeline model as outputs.
for (const auto& input_name : input_names_) {
if (pipeline_state->HasOutput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed input " << input_name << " resides on the primary device type ("
<< to_string(model_.device_type_) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
throw std::runtime_error(oss.str());
}
pipeline_state->output_names_.push_back(input_name);
pipeline_state->outputs_.push_back(State::GetInput(input_name));
}
}

// Add all the remaining outputs for the intermediate pipeline state
for (const auto& output_name : model_.config_->model.decoder.pipeline[pipeline_state->id_].outputs) {
if (std::none_of(pipeline_state->output_names_.begin(), pipeline_state->output_names_.end(),
Expand All @@ -179,16 +204,31 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
// Run the intermediate pipeline state
pipeline_state->Run(current_length, next_tokens, next_indices);

// Store the non-managed outputs from the current pipeline state in the ortvalue pool.
// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
// All non-managed outputs are assumed to reside on CPU.
for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) {
if (std::none_of(output_names_.begin(), output_names_.end(),
[&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; }) &&
std::none_of(input_names_.begin(), input_names_.end(),
[&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; })) {
auto forwarded_output = model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.find(pipeline_state->output_names_[i]);
if (forwarded_output != model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.end()) {
ortvalue_pool_[forwarded_output->second] = pipeline_state->outputs_[i];
ortvalue_store_[forwarded_output->second] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]);
} else {
ortvalue_pool_[pipeline_state->output_names_[i]] = pipeline_state->outputs_[i];
ortvalue_store_[pipeline_state->output_names_[i]] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]);
}
}
}
}

// Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier.
if (!first_run_) {
for (auto& pipeline_state : pipeline_states_) {
if (!model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_token_gen) {
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
}
}
}
}
Expand All @@ -208,10 +248,10 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray<int32_t>&
}

OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) {
// Check the ortvalue pool to search if name is one of the non-managed output.
auto it = ortvalue_pool_.find(name);
if (it != ortvalue_pool_.end()) {
return it->second;
// Check the ortvalue store to see whether `name` is one of the non-managed outputs.
auto it = ortvalue_store_.find(name);
if (it != ortvalue_store_.end()) {
return it->second.get();
}

// Search managed outputs saved in this State.
Expand Down
2 changes: 1 addition & 1 deletion src/models/decoder_only_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct DecoderOnlyPipelineState : State {
std::vector<std::unique_ptr<IntermediatePipelineState>> pipeline_states_;

// Stores all the outputs from the previous pipeline state(s)
std::unordered_map<std::string, OrtValue*> ortvalue_pool_;
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;

InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
Expand Down
5 changes: 5 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
session_options.SetEpContextFilePath(config_session_options.ep_context_file_path.value().c_str());
}

if (config_session_options.provider_options.empty() && config_session_options.use_env_allocators) {
// Share env allocators across sessions that only use the CPU provider
session_options.AddConfigEntry("session.use_env_allocators", "1");
}

for (auto& provider_options : config_session_options.provider_options) {
if (provider_options.name == "cuda") {
auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();
Expand Down

0 comments on commit 2348dc9

Please sign in to comment.