Skip to content

Commit

Permalink
Simplify Model IO handler constructors (#925)
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanUnderhill authored Sep 26, 2024
1 parent 1e4d289 commit 205e997
Show file tree
Hide file tree
Showing 20 changed files with 78 additions and 86 deletions.
8 changes: 4 additions & 4 deletions src/models/decoder_only.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ struct DecoderOnly_State : State {
const DecoderOnly_Model& model_;
CapturedGraphInfoPtr captured_graph_info_;

InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
KV_Cache kv_cache_{model_, *this};
InputIDs input_ids_{*this};
Logits logits_{*this};
KV_Cache kv_cache_{*this};
PositionInputs position_inputs_;
ExtraInputs extra_inputs_{model_, *this};
ExtraInputs extra_inputs_{*this};
};

} // namespace Generators
2 changes: 1 addition & 1 deletion src/models/decoder_only_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
position_inputs_.Add();
logits_.Add();
if (KV_Cache::IsCacheNeeded(model)) {
kv_cache_ = std::make_unique<KV_Cache>(model, *this);
kv_cache_ = std::make_unique<KV_Cache>(*this);
kv_cache_->Add();
}
extra_inputs_.Add();
Expand Down
6 changes: 3 additions & 3 deletions src/models/decoder_only_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ struct DecoderOnlyPipelineState : State {
// Stores all the outputs from the previous pipeline state(s)
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;

InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
InputIDs input_ids_{*this};
Logits logits_{*this};
std::unique_ptr<KV_Cache> kv_cache_;
PositionInputs position_inputs_;
ExtraInputs extra_inputs_{model_, *this};
ExtraInputs extra_inputs_{*this};
};

} // namespace Generators
5 changes: 2 additions & 3 deletions src/models/embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

namespace Generators {

Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode, const std::string& name)
: model_{model},
state_{state},
Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& name)
: state_{state},
shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams,
state_.params_->sequence_length, model_.config_->model.decoder.hidden_size},
type_{mode == Embeddings::Mode::Input
Expand Down
4 changes: 2 additions & 2 deletions src/models/embeddings.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct Embeddings {
Output
};

Embeddings(const Model& model, State& state, Embeddings::Mode mode, const std::string& name);
Embeddings(State& state, Embeddings::Mode mode, const std::string& name);
Embeddings(const Embeddings&) = delete;
Embeddings& operator=(const Embeddings&) = delete;

Expand All @@ -26,8 +26,8 @@ struct Embeddings {
auto& GetShape() const { return shape_; }

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};
std::array<int64_t, 3> shape_{}; // [batch_size, sequence_length, hidden_size]
ONNXTensorElementDataType type_;
const Mode mode_{};
Expand Down
5 changes: 2 additions & 3 deletions src/models/extra_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

namespace Generators {

ExtraInputs::ExtraInputs(const Model& model, State& state)
: model_{model},
state_{state} {
ExtraInputs::ExtraInputs(State& state)
: state_{state} {
extra_inputs_.reserve(state_.params_->extra_inputs.size());

if (state_.GetCapturedGraphInfo()) {
Expand Down
4 changes: 2 additions & 2 deletions src/models/extra_inputs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
namespace Generators {

struct ExtraInputs {
ExtraInputs(const Model& model, State& state);
ExtraInputs(State& state);
void Add();

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};
std::vector<OrtValue*> extra_inputs_;
std::vector<std::unique_ptr<OrtValue>> owned_extra_inputs_;
std::unordered_map<std::string, StaticBuffer*> sb_extra_inputs_;
Expand Down
8 changes: 4 additions & 4 deletions src/models/gpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ struct Gpt_State : State {

const Gpt_Model& model_;

InputIDs input_ids_{model_, *this};
Logits logits_{model_, *this};
KV_Cache_Combined kv_cache_{model_, *this};
InputIDs input_ids_{*this};
Logits logits_{*this};
KV_Cache_Combined kv_cache_{*this};
PositionInputs position_inputs_;
ExtraInputs extra_inputs_{model_, *this};
ExtraInputs extra_inputs_{*this};
};
} // namespace Generators
7 changes: 3 additions & 4 deletions src/models/image_features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

namespace Generators {

ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens)
: model_{model},
state_{state},
shape_{num_image_tokens, model.config_->model.decoder.hidden_size},
ImageFeatures::ImageFeatures(State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens)
: state_{state},
shape_{num_image_tokens, model_.config_->model.decoder.hidden_size},
type_{mode == ImageFeatures::Mode::Input
? model_.session_info_->GetInputDataType(name)
: model_.session_info_->GetOutputDataType(name)},
Expand Down
4 changes: 2 additions & 2 deletions src/models/image_features.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct ImageFeatures {
Output
};

ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens);
ImageFeatures(State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens);
ImageFeatures(const ImageFeatures&) = delete;
ImageFeatures& operator=(const ImageFeatures&) = delete;

Expand All @@ -23,8 +23,8 @@ struct ImageFeatures {
OrtValue* Get() { return image_features_.get(); }

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};

std::array<int64_t, 2> shape_{}; // [num_image_tokens, hidden_size]
ONNXTensorElementDataType type_;
Expand Down
9 changes: 4 additions & 5 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@

namespace Generators {

InputIDs::InputIDs(const Model& model, State& state)
: model_{model},
state_{state} {
InputIDs::InputIDs(State& state)
: state_{state} {
name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
shape_ = {state_.params_->batch_size, state_.params_->sequence_length};
type_ = model_.session_info_->GetInputDataType(name_);

// If 64-bit, convert from 32-bit to 64-bit
if (type_ == Ort::TypeToTensorType<int64_t>) {
value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_);
value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
auto* p_data = value_->GetTensorMutableData<int64_t>();
for (auto v : state_.params_->input_ids) {
*p_data++ = v;
}
} else {
if (type_ != Ort::TypeToTensorType<int32_t>)
throw std::runtime_error("InputIDs must be int64 or int32");
value_ = OrtValue::CreateTensor<int32_t>(model.allocator_cpu_.GetInfo(), std::span<int32_t>(const_cast<int32_t*>(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_);
value_ = OrtValue::CreateTensor<int32_t>(model_.allocator_cpu_.GetInfo(), std::span<int32_t>(const_cast<int32_t*>(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_);
}

value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams);
Expand Down
4 changes: 2 additions & 2 deletions src/models/input_ids.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace Generators {

struct InputIDs {
InputIDs(const Model& model, State& state);
InputIDs(State& state);
InputIDs(const InputIDs&) = delete;
InputIDs& operator=(const InputIDs&) = delete;

Expand All @@ -18,8 +18,8 @@ struct InputIDs {
OrtValue* Get() { return value_.get(); }

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};
size_t input_index_{~0U};

std::array<int64_t, 2> shape_{};
Expand Down
45 changes: 21 additions & 24 deletions src/models/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ std::string ComposeKeyValueName(const std::string& template_string, int index) {

} // namespace

KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state)
: model_{model},
state_{state},
layer_count_{model.config_->model.decoder.num_hidden_layers},
shape_{2, state_.params_->BatchBeamSize(), model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size} {
KV_Cache_Combined::KV_Cache_Combined(State& state)
: state_{state},
layer_count_{model_.config_->model.decoder.num_hidden_layers},
shape_{2, state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} {
pasts_.resize(layer_count_);
presents_.reserve(layer_count_);

for (int i = 0; i < layer_count_; ++i) {
input_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.inputs.past_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.outputs.present_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_names, i));
}

// Derive the KV data type from the KV input 0
Expand All @@ -39,7 +38,7 @@ KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state)
shape_[3] = state_.params_->sequence_length;

for (int i = 0; i < layer_count_; ++i) {
presents_.push_back(OrtValue::CreateTensor(*model.allocator_device_, shape_, type_));
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
}
}

Expand Down Expand Up @@ -128,24 +127,23 @@ bool KV_Cache::IsCacheNeeded(const Model& model) {
return model.session_info_->HasInput(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, 0));
}

KV_Cache::KV_Cache(const Model& model, State& state)
: model_{model},
state_{state},
KV_Cache::KV_Cache(State& state)
: state_{state},
layer_count_{model_.config_->model.decoder.num_hidden_layers},
past_present_share_buffer_{state_.params_->search.past_present_share_buffer && (state_.params_->search.num_beams == 1 || model_.config_->model.type == "whisper")},
shape_{state_.params_->BatchBeamSize(), model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size} {
shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} {
if (g_log.enabled && g_log.warning && past_present_share_buffer_ != state_.params_->search.past_present_share_buffer)
Log("warning", "past_present_share_buffer search option set to true, but has been disabled due to the current configuration. See https://aka.ms/generate_config for details");

pasts_.resize(layer_count_ * 2);
presents_.reserve(layer_count_ * 2);

for (int i = 0; i < layer_count_; ++i) {
input_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.inputs.past_key_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.inputs.past_value_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_key_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_value_names, i));

output_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.outputs.present_key_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.outputs.present_value_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_key_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_value_names, i));
}

// Derive the KV data type from the KV input 0
Expand Down Expand Up @@ -264,19 +262,18 @@ void KV_Cache::PickPastState(std::span<const int32_t> beam_indices, int index) {
}
}

Cross_Cache::Cross_Cache(const Model& model, State& state)
: model_{model},
state_{state},
Cross_Cache::Cross_Cache(State& state)
: state_{state},
layer_count_{model_.config_->model.decoder.num_hidden_layers},
shape_{state_.params_->BatchBeamSize(), model.config_->model.decoder.num_key_value_heads, 1500, model.config_->model.decoder.head_size} {
shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 1500, model_.config_->model.decoder.head_size} {
values_.reserve(layer_count_ * 2);

for (int i = 0; i < layer_count_; ++i) {
input_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.inputs.cross_past_key_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.inputs.cross_past_value_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.cross_past_key_names, i));
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.cross_past_value_names, i));

output_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.outputs.cross_present_key_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model.config_->model.decoder.outputs.cross_present_value_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.cross_present_key_names, i));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.cross_present_value_names, i));
}

// Derive the KV data type from the KV input 0
Expand Down
12 changes: 6 additions & 6 deletions src/models/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace Generators {

struct KV_Cache_Combined {
KV_Cache_Combined(const Model& model, State& state);
KV_Cache_Combined(State& state);

void Add(); // Add to state inputs/outputs
void Update(std::span<const int32_t> beam_indices, int current_length);
Expand All @@ -15,8 +15,8 @@ struct KV_Cache_Combined {
void PickPastState(std::span<const int32_t> beam_indices, int index);

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};
int layer_count_;
size_t input_index_{~0U}, output_index_{~0U};

Expand All @@ -29,7 +29,7 @@ struct KV_Cache_Combined {
};

struct KV_Cache {
KV_Cache(const Model& model, State& state);
KV_Cache(State& state);

static bool IsCacheNeeded(const Model& model);

Expand All @@ -41,8 +41,8 @@ struct KV_Cache {
void PickPastState(std::span<const int32_t> beam_indices, int index);

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};
int layer_count_;
size_t input_index_{~0U}, output_index_{~0U};
bool past_present_share_buffer_; // True if model.decoder.past_present_share_buffer is set to true, and we're using cuda, and not beam search
Expand All @@ -58,14 +58,14 @@ struct KV_Cache {

// Very similar to the KV_Cache, but is only created once at the encoder step, then used without modification for every decoder step
struct Cross_Cache {
Cross_Cache(const Model& model, State& state);
Cross_Cache(State& state);

void AddOutputs();
void AddInputs();

private:
const Model& model_;
State& state_;
const Model& model_{state_.model_};
int layer_count_;

std::array<int64_t, 4> shape_;
Expand Down
5 changes: 2 additions & 3 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@

namespace Generators {

Logits::Logits(const Model& model, State& state)
: model_{model},
state_{state},
Logits::Logits(State& state)
: state_{state},
shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, model_.config_->model.vocab_size},
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
Expand Down
4 changes: 2 additions & 2 deletions src/models/logits.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace Generators {

struct Logits {
Logits(const Model& model, State& state);
Logits(State& state);

void Add();
RoamingArray<float> Get();
Expand All @@ -17,8 +17,8 @@ struct Logits {
private:
void HandleEOSArray(cpu_span<float> logits);

const Model& model_;
State& state_;
const Model& model_{state_.model_};
size_t output_index_{~0U};

std::array<int64_t, 3> shape_{};
Expand Down
4 changes: 2 additions & 2 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ static std::string CurrentModulePath() {
namespace Generators {

State::State(const GeneratorParams& params, const Model& model)
: params_{params.shared_from_this()},
model_{model} {}
: model_{model},
params_{params.shared_from_this()} {}

void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) {
if (first_run_) {
Expand Down
2 changes: 1 addition & 1 deletion src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct State {

void ClearIO(); // Clear all inputs/outputs

const Model& model_;
std::shared_ptr<const GeneratorParams> params_;

std::vector<const char*> input_names_, output_names_;
Expand All @@ -49,7 +50,6 @@ struct State {
bool first_run_{true};

private:
const Model& model_;
int current_batch_size_{0};
};

Expand Down
Loading

0 comments on commit 205e997

Please sign in to comment.