Optimize Decoder Pipeline Model Execution #907

Merged · 7 commits · Sep 24, 2024 · changes shown from 5 commits
7 changes: 6 additions & 1 deletion src/generators.cpp
@@ -13,7 +13,12 @@ namespace Generators {

static bool _ = (Ort::InitApi(), false);

OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {}
OrtGlobals::OrtGlobals()
: env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)},
arena_config_{OrtArenaCfg::Create(0, -1, -1, -1)} {
Ort::Allocator& allocator_cpu{Ort::Allocator::GetWithDefaultOptions()};
env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config_);
}

// Ensure Shutdown() has been called before process exit
struct ValidateShutdown {
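For context: the constructor change above registers a single CPU arena allocator on the environment, and the model.cpp change at the end of this diff opts CPU-only sessions into it. A minimal standalone sketch of the same pattern, written against the public onnxruntime C++ API rather than this repo's Ort wrappers (the model path is a placeholder):

```cpp
// Standalone sketch of env-level allocator sharing using the public
// onnxruntime C++ API; not code from this PR.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_ERROR, "env-allocator-sharing"};

  // Mirrors OrtArenaCfg::Create(0, -1, -1, -1): 0/-1 mean "use the default"
  // for max_mem, arena_extend_strategy, initial_chunk_size_bytes, and
  // max_dead_bytes_per_chunk respectively.
  Ort::ArenaCfg arena_cfg{0, -1, -1, -1};

  // Register a CPU arena allocator with the environment, keyed by its
  // memory info, so that sessions can share it.
  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  env.CreateAndRegisterAllocator(memory_info, arena_cfg);

  // A session only uses env-registered allocators if it opts in.
  Ort::SessionOptions session_options;
  session_options.AddConfigEntry("session.use_env_allocators", "1");
  // Ort::Session session{env, "model.onnx", session_options};  // placeholder path

  return 0;
}
```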
1 change: 1 addition & 0 deletions src/generators.h
@@ -146,6 +146,7 @@ struct OrtGlobals {
OrtGlobals();

std::unique_ptr<OrtEnv> env_;
std::unique_ptr<OrtArenaCfg> arena_config_;
#if USE_CUDA
std::unique_ptr<OrtMemoryInfo> memory_info_cuda_;
std::unique_ptr<Ort::Allocator> allocator_cuda_;
60 changes: 50 additions & 10 deletions src/models/decoder_only_pipeline.cpp
@@ -119,7 +119,13 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
continue;
}

// Clear the intermediate pipeline state from previous runs.
// Clear the intermediate pipeline state outputs from the previous runs.
// These outputs will be replaced by the outputs from the current run.
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
}
}
pipeline_state->ClearIO();

// Managed inputs and outputs are those inputs and outputs that the
@@ -143,10 +149,10 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}

// Add outputs from the previous pipeline states to the current pipeline state
for (auto& [name, ortvalue] : ortvalue_pool_) {
for (auto& [name, ortvalue] : ortvalue_store_) {
if (pipeline_state->HasInput(name)) {
pipeline_state->input_names_.push_back(name.c_str());
pipeline_state->inputs_.push_back(ortvalue);
pipeline_state->inputs_.push_back(ortvalue.get());
}
}

@@ -167,6 +173,25 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}
}

// Outputs of pipeline models could also be managed inputs.
// For example, the output of a pipeline model could be the key-value cache.
// In such cases, use the managed output buffers and register them with the pipeline model as outputs.
for (const auto& input_name : input_names_) {
if (pipeline_state->HasOutput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed input " << input_name << " resides on the primary device type ("
<< to_string(model_.device_type_) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
throw std::runtime_error(oss.str());
}
pipeline_state->output_names_.push_back(input_name);
pipeline_state->outputs_.push_back(State::GetInput(input_name));
}
}

// Add all the remaining outputs for the intermediate pipeline state
for (const auto& output_name : model_.config_->model.decoder.pipeline[pipeline_state->id_].outputs) {
if (std::none_of(pipeline_state->output_names_.begin(), pipeline_state->output_names_.end(),
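Conceptually, the block above lets a pipeline model write straight into a managed input buffer such as the key-value cache instead of allocating a fresh output. The same in-place idea can be pictured with ONNX Runtime's public IoBinding API; a sketch under the assumption of an already-created session (the PR itself wires buffers through State's input/output vectors rather than IoBinding, and the tensor names are hypothetical):

```cpp
// Sketch: bind one pre-allocated tensor as both an input and an output so
// the model updates it in place. Assumes an existing session; names are
// hypothetical.
#include <onnxruntime_cxx_api.h>

void RunKvCacheInPlace(Ort::Session& session, Ort::Value& kv_cache) {
  Ort::IoBinding binding{session};
  binding.BindInput("past_key_values.0.key", kv_cache);  // hypothetical name
  binding.BindOutput("present.0.key", kv_cache);         // hypothetical name
  Ort::RunOptions run_options;
  session.Run(run_options, binding);
}
```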
@@ -179,16 +204,31 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
// Run the intermediate pipeline state
pipeline_state->Run(current_length, next_tokens, next_indices);

// Store the non-managed outputs from the current pipeline state in the ortvalue pool.
// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
// All non-managed outputs are assumed to be on CPU
for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) {
if (std::none_of(output_names_.begin(), output_names_.end(),
[&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; }) &&
std::none_of(input_names_.begin(), input_names_.end(),
[&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; })) {
auto forwarded_output = model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.find(pipeline_state->output_names_[i]);
if (forwarded_output != model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.end()) {
ortvalue_pool_[forwarded_output->second] = pipeline_state->outputs_[i];
ortvalue_store_[forwarded_output->second] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]);
} else {
ortvalue_pool_[pipeline_state->output_names_[i]] = pipeline_state->outputs_[i];
ortvalue_store_[pipeline_state->output_names_[i]] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]);
}
}
}
}

// Clear the outputs of the pipeline models that only run on the prompt, since this cannot happen earlier.
if (!first_run_) {
for (auto& pipeline_state : pipeline_states_) {
if (!model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_token_gen) {
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
}
}
}
}
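The output_names_forwarder branch above stores an output under a different key, so a value produced under one name can later be found under the name a downstream pipeline state asks for; the else branch keeps the original name. Either way the store takes ownership of the raw pointer. A minimal sketch of that step, with a hypothetical stand-in type and name pair:

```cpp
// Sketch of the name-forwarding step; Value and all names are hypothetical.
#include <memory>
#include <string>
#include <unordered_map>

struct Value {};  // stand-in for OrtValue

int main() {
  // Maps an output name to the name it should be stored (and later looked
  // up) under; in the PR this comes from the config's output_names_forwarder.
  std::unordered_map<std::string, std::string> forwarder{
      {"present.0.key", "past.0.key"}};  // hypothetical pair
  std::unordered_map<std::string, std::unique_ptr<Value>> store;

  std::string output_name = "present.0.key";
  Value* raw = new Value{};  // stands in for pipeline_state->outputs_[i]

  // Take ownership of the raw pointer, keyed by the forwarded name if any.
  auto it = forwarder.find(output_name);
  const std::string& key = (it != forwarder.end()) ? it->second : output_name;
  store[key] = std::unique_ptr<Value>(raw);
  return 0;
}
```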
@@ -208,10 +248,10 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray<int32_t>&
}

OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) {
// Check the ortvalue pool to search if name is one of the non-managed output.
auto it = ortvalue_pool_.find(name);
if (it != ortvalue_pool_.end()) {
return it->second;
// Check the ortvalue store to see if name is one of the non-managed outputs.
auto it = ortvalue_store_.find(name);
if (it != ortvalue_store_.end()) {
return it->second.get();
}

// Search managed outputs saved in this State.
2 changes: 1 addition & 1 deletion src/models/decoder_only_pipeline.h
@@ -66,7 +66,7 @@ struct DecoderOnlyPipelineState : State {
std::vector<std::unique_ptr<IntermediatePipelineState>> pipeline_states_;

// Stores all the outputs from the previous pipeline state(s)
std::unordered_map<std::string, OrtValue*> ortvalue_pool_;
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;

InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
5 changes: 5 additions & 0 deletions src/models/model.cpp
@@ -364,6 +364,11 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
session_options.SetEpContextFilePath(config_session_options.ep_context_file_path.value().c_str());
}

if (config_session_options.provider_options.empty()) {
// Share env allocators across sessions that only use the CPU provider
session_options.AddConfigEntry("session.use_env_allocators", "1");
}

for (auto& provider_options : config_session_options.provider_options) {
if (provider_options.name == "cuda") {
auto ort_provider_options = OrtCUDAProviderOptionsV2::Create();
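Note that the model.cpp entry above is only added when provider_options is empty: CPU-only sessions opt into the arena registered on the env in generators.cpp, while sessions configured with CUDA or other providers are left untouched.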