Remove redundant config copy from GeneratorParams #920

Merged: 3 commits, Sep 25, 2024
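In short, this PR drops the values that GeneratorParams used to copy out of the model configuration (pad_token_id, eos_token_id, vocab_size, hidden_size, plus a non-owning config_ pointer) and replaces them with a single const Config& member; only the mutable Config::Search options are still copied. The diffs below then switch every consumer from the copied fields to params.config.model.*. A minimal sketch of the ownership pattern, using hypothetical stand-in types rather than the real Generators::Config:

#include <cstdio>

// Hypothetical stand-ins for the real Generators::Config types, only to
// illustrate the pattern this PR adopts (not the actual declarations).
struct Config {
  struct Model {
    int pad_token_id{0};
    int eos_token_id{2};
    int vocab_size{32000};
  } model;
  struct Search {
    int num_beams{1};
    int max_length{512};
  } search;
};

// Before (sketch): each needed value was copied into the params object.
struct ParamsBefore {
  Config::Search search;
  int pad_token_id{};
  int eos_token_id{};
  int vocab_size{};
};

// After (sketch): the params object references the config owned by the model,
// which outlives it; only the mutable search options are still copied.
struct ParamsAfter {
  const Config& config;                  // no copy, reads go through the model's config
  Config::Search search{config.search};  // callers can tweak search without touching the config
};

int main() {
  Config config;               // normally owned by the Model
  ParamsAfter params{config};  // aggregate init binds the reference, then copies search
  std::printf("pad=%d beams=%d\n", params.config.model.pad_token_id, params.search.num_beams);
  return 0;
}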
4 changes: 2 additions & 2 deletions src/beam_search_scorer.cpp
@@ -47,8 +47,8 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
: batch_size_{parameters.batch_size},
num_beams_{parameters.search.num_beams},
max_length_{parameters.search.max_length},
- pad_token_id_{parameters.pad_token_id},
- eos_token_id_{parameters.eos_token_id},
+ pad_token_id_{parameters.config.model.pad_token_id},
+ eos_token_id_{parameters.config.model.eos_token_id},
early_stopping_{parameters.search.early_stopping},
not_done_count_{parameters.batch_size} {
size_t const batch_beam_size = static_cast<size_t>(batch_size_) * num_beams_;
4 changes: 2 additions & 2 deletions src/beam_search_scorer_cuda.cpp
@@ -12,8 +12,8 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
state_cpu_->batch_size_ = static_cast<size_t>(parameters.batch_size);
state_cpu_->num_beams_ = static_cast<size_t>(parameters.search.num_beams);
state_cpu_->max_length_ = static_cast<size_t>(parameters.search.max_length);
- state_cpu_->pad_token_id_ = parameters.pad_token_id;
- state_cpu_->eos_token_id_ = parameters.eos_token_id;
+ state_cpu_->pad_token_id_ = parameters.config.model.pad_token_id;
+ state_cpu_->eos_token_id_ = parameters.config.model.eos_token_id;
state_cpu_->early_stopping_ = parameters.search.early_stopping;
state_cpu_->not_done_count_ = parameters.batch_size;
state_cpu_->hypothesis_buffer_used_ = 0;
18 changes: 8 additions & 10 deletions src/generators.cpp
@@ -64,16 +64,14 @@ std::string to_string(DeviceType device_type) {
throw std::runtime_error("Unknown device type");
}

+ GeneratorParams::GeneratorParams(const Config& config) : config{config} {
+ }
+
GeneratorParams::GeneratorParams(const Model& model)
- : search{model.config_->search},
- pad_token_id{model.config_->model.pad_token_id},
- eos_token_id{model.config_->model.eos_token_id},
- vocab_size{model.config_->model.vocab_size},
- hidden_size{model.config_->model.decoder.hidden_size},
+ : config{*model.config_.get()},
device_type{model.device_type_},
cuda_stream{model.cuda_stream_},
- is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)},
- config_{model.config_.get()} {
+ is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
use_cuda_graph = is_cuda_graph_enabled_;
if (use_cuda_graph) {
max_batch_size = 1; // set it to 1 by default
@@ -107,7 +105,7 @@ void GeneratorParams::SetInputs(const NamedTensors& named_tensors) {
} else {
// If the nominal name is found in the map, use the graph name.
// Else, use the nominal name as the graph name.
- [[maybe_unused]] const auto [graph_name, found] = config_->GetGraphName(name);
+ [[maybe_unused]] const auto [graph_name, found] = config.GetGraphName(name);
extra_inputs.push_back({graph_name, tensor});
}
}
@@ -139,8 +137,8 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")");
if (params.batch_size < 1)
throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size));
- if (params.vocab_size < 1)
- throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size));
+ if (params.config.model.vocab_size < 1)
+ throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.config.model.vocab_size));
if (params.sequence_length >= params.search.max_length)
throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")");
if (params.input_ids.empty() || params.input_ids.data() == nullptr)
33 changes: 4 additions & 29 deletions src/generators.h
@@ -57,44 +57,21 @@ enum struct DeviceType {
std::string to_string(DeviceType device_type);

struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChecked<GeneratorParams> {
- GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in
+ GeneratorParams(const Config& config); // This constructor is only used for internal generator benchmarks
GeneratorParams(const Model& model);

- Config::Search search;
-
- // Read only values copied from model
- int pad_token_id{};
- int eos_token_id{};
- int vocab_size{};
- int context_length{};
+ const Config& config; // The model outlives the GeneratorParams
+ Config::Search search{config.search}; // Copy of the search parameters from the config

int batch_size{1};
int max_batch_size{0};
bool use_cuda_graph{};
int sequence_length{};
- int hidden_size{};
int BatchBeamSize() const { return search.num_beams * batch_size; }

DeviceType device_type{DeviceType::CPU};
cudaStream_t cuda_stream{};

- #if 0
- struct Bert {
- std::span<const int32_t> input_ids; // Array of [batchsize][sequence_length]
- };
-
- struct Gpt {
- using Gpt=Bert;
- };
-
- struct T5 {
- std::span<const int32_t> encoder_input_ids; // Array of [batchsize][sequence_length]
- std::span<const int32_t> decoder_input_ids; // Array of [batchsize][sequence_length]
- };
- using Bart=T5;
-
- #endif

// TODO: Move this to a separate GPT struct
std::span<const int32_t> input_ids; // Array of [batchsize][sequence_length]

@@ -123,8 +100,6 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec

private:
bool is_cuda_graph_enabled_{};
- const Config* config_{nullptr}; // Non owning pointer to the config.
- // The model outlives the GeneratorParams
};

struct Generator : LeakChecked<Generator> {
@@ -161,7 +136,7 @@ OrtEnv& GetOrtEnv();

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
- std::shared_ptr<GeneratorParams> CreateGeneratorParams(); // For benchmarking purposes only
+ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); // For benchmarking purposes only
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
std::vector<std::vector<int32_t>> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence

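One C++ detail in the new GeneratorParams layout above: search is initialized from the config reference through a default member initializer, which only works because config is declared first (members are initialized in declaration order), and holding a reference member means the struct can no longer be default-constructed the way the old GeneratorParams() = default; allowed. A small self-contained illustration with toy types (not the real ones):

#include <cassert>

// Toy stand-ins; the real Config/GeneratorParams live in the onnxruntime-genai sources.
struct Cfg {
  struct Search { int num_beams{4}; } search;
};

struct Params {
  const Cfg& cfg;                  // declared first, so it is initialized first
  Cfg::Search search{cfg.search};  // the default member initializer may safely read cfg
  // No default constructor is possible: the reference member must be bound at construction.
};

int main() {
  Cfg cfg;
  Params p{cfg};            // bind the reference, then copy cfg.search
  p.search.num_beams = 8;   // mutating the copy leaves the shared config untouched
  assert(cfg.search.num_beams == 4 && p.search.num_beams == 8);
  return 0;
}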
2 changes: 1 addition & 1 deletion src/models/embeddings.cpp
@@ -11,7 +11,7 @@ Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode,
: model_{model},
state_{state},
shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams,
- state_.params_->sequence_length, state_.params_->hidden_size},
+ state_.params_->sequence_length, model_.config_->model.decoder.hidden_size},
type_{mode == Embeddings::Mode::Input
? model_.session_info_->GetInputDataType(name)
: model_.session_info_->GetOutputDataType(name)},
2 changes: 1 addition & 1 deletion src/models/image_features.cpp
@@ -9,7 +9,7 @@ namespace Generators {
ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens)
: model_{model},
state_{state},
- shape_{num_image_tokens, state_.params_->hidden_size},
+ shape_{num_image_tokens, model.config_->model.decoder.hidden_size},
type_{mode == ImageFeatures::Mode::Input
? model_.session_info_->GetInputDataType(name)
: model_.session_info_->GetOutputDataType(name)},
2 changes: 1 addition & 1 deletion src/models/input_ids.cpp
@@ -55,7 +55,7 @@ InputIDs::InputIDs(const Model& model, State& state)
if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("Batch size must be 1 for current_sequence_length and past_sequence_length inputs");
}
- const int32_t current_sequence_length = get_unpadded_sequence_length(state_.params_->input_ids, state_.params_->pad_token_id);
+ const int32_t current_sequence_length = get_unpadded_sequence_length(state_.params_->input_ids, model_.config_->model.pad_token_id);
const std::array<int64_t, 1> current_sequence_length_shape{1};
const std::array<int64_t, 2> past_sequence_length_shape{1, 1};

4 changes: 2 additions & 2 deletions src/models/logits.cpp
@@ -12,7 +12,7 @@ namespace Generators {
Logits::Logits(const Model& model, State& state)
: model_{model},
state_{state},
- shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, state_.params_->vocab_size},
+ shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, model_.config_->model.vocab_size},
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);

@@ -70,7 +70,7 @@ RoamingArray<float> Logits::Get() {
// Find the first non pad token from the end
size_t token_index = seq_length;
while (token_index-- > 0) {
- if (input_ids[token_index] != state_.params_->pad_token_id)
+ if (input_ids[token_index] != model_.config_->model.pad_token_id)
break;
}

4 changes: 2 additions & 2 deletions src/models/model.cpp
@@ -508,8 +508,8 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model) {
}

// Used by benchmarking tests only, should not be used normally
- std::shared_ptr<GeneratorParams> CreateGeneratorParams() {
- return std::make_shared<GeneratorParams>();
+ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config) {
+ return std::make_shared<GeneratorParams>(config);
}

void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream) {
2 changes: 1 addition & 1 deletion src/models/position_inputs.cpp
@@ -334,7 +334,7 @@ void PositionInputs::InitializeTensors(std::array<int64_t, 2> shape, cpu_span<in
for (int i = 0; i < shape[0]; i++) {
T abs_position = 0;
for (int j = 0; j < shape[1]; j++, word_id++, mask++, position++) {
- if (*word_id == state_.params_->pad_token_id) {
+ if (*word_id == model_.config_->model.pad_token_id) {
*mask = 0;
*position = 0;
} else {
2 changes: 1 addition & 1 deletion src/ort_genai_c.cpp
@@ -196,7 +196,7 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams*
span_sequences.emplace_back(sequences[i]);
}

- params.input_ids_owner = Generators::PadInputs(span_sequences, params.pad_token_id);
+ params.input_ids_owner = Generators::PadInputs(span_sequences, params.config.model.pad_token_id);
params.batch_size = static_cast<int>(sequences.size());
params.sequence_length = static_cast<int>(params.input_ids_owner.size() / params.batch_size);
params.input_ids = params.input_ids_owner;
7 changes: 4 additions & 3 deletions src/python/python.cpp
@@ -375,9 +375,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) {

pybind11::class_<PyGeneratorParams>(m, "GeneratorParams")
.def(pybind11::init<const Model&>())
.def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->pad_token_id; })
.def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->eos_token_id; })
.def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->vocab_size; })
// TODO(ryanhill): Remove these entirely or replace with a single property that returns the entire config?
.def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->config.model.pad_token_id; })
.def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->config.model.eos_token_id; })
.def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->config.model.vocab_size; })
.def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_)
// TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic
.def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
26 changes: 13 additions & 13 deletions src/search.cpp
@@ -64,7 +64,7 @@ int Search_Cpu::GetSequenceLength() const {
void BeamSearch_Cpu::SelectTop() {
// Normalize next token scores
for (int i = 0; i < params_->BatchBeamSize(); i++) {
- std::span<float> const scores = next_token_scores_.subspan(static_cast<size_t>(i) * static_cast<size_t>(params_->vocab_size), params_->vocab_size);
+ std::span<float> const scores = next_token_scores_.subspan(static_cast<size_t>(i) * static_cast<size_t>(params_->config.model.vocab_size), params_->config.model.vocab_size);
LogSoftMax(scores, 1.0);
}

@@ -77,7 +77,7 @@ void BeamSearch_Cpu::SelectTop() {
int batch_beam_index = 0;
for (int i = 0; i < params_->batch_size; i++) {
for (int j = 0; j < params_->search.num_beams; j++, batch_beam_index++) {
- for (int k = 0; k < params_->vocab_size; k++, offset++) {
+ for (int k = 0; k < params_->config.model.vocab_size; k++, offset++) {
next_token_scores_[offset] += beam_scores[batch_beam_index];
}
}
@@ -103,7 +103,7 @@ void BeamSearch_Cpu::SelectTop() {
// TODO(aciddelgado): Optimize this top k with partial sort
for (size_t batch_index = 0; batch_index < static_cast<size_t>(params_->batch_size); batch_index++) {
std::priority_queue<ScoreIndex, std::vector<ScoreIndex>> queue;
- auto token_scores_sub = next_token_scores_.subspan(batch_index * params_->search.num_beams * params_->vocab_size, static_cast<size_t>(params_->search.num_beams) * params_->vocab_size);
+ auto token_scores_sub = next_token_scores_.subspan(batch_index * params_->search.num_beams * params_->config.model.vocab_size, static_cast<size_t>(params_->search.num_beams) * params_->config.model.vocab_size);
for (int i = 0; i < token_scores_sub.size(); i++) {
queue.push({token_scores_sub[i], i});
}
@@ -113,8 +113,8 @@ void BeamSearch_Cpu::SelectTop() {
auto next_scores_sub = next_scores.subspan(top_k * batch_index, top_k);
for (unsigned i = 0; i < top_k; i++) {
auto v = queue.top();
- next_indices_sub[i] = v.index / params_->vocab_size;
- next_tokens_sub[i] = v.index % params_->vocab_size;
+ next_indices_sub[i] = v.index / params_->config.model.vocab_size;
+ next_tokens_sub[i] = v.index % params_->config.model.vocab_size;
next_scores_sub[i] = v.score;
queue.pop();
}
@@ -139,7 +139,7 @@ void GreedySearch_Cpu::SelectTop() {
continue;
}

- std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size);
+ std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size);
auto const token = static_cast<int32_t>(std::distance(scores.begin(), std::max_element(scores.begin(), scores.end())));
SetNextToken(batch_id, token);
}
@@ -149,7 +149,7 @@

void GreedySearch_Cpu::SampleTopK(int k, float temperature) {
for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) {
- std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size);
+ std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size);
SoftMax(scores, temperature);
// Find the top K scores
std::vector<int> indices(scores.size());
@@ -168,7 +168,7 @@ void GreedySearch_Cpu::SampleTopP(float p, float temperature) {
if (PadIfAlreadyEOS(batch_id)) {
continue;
}
- std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size);
+ std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size);
SoftMax(scores, temperature);
// Sort an array of indices into the scores
std::vector<int32_t> indices(scores.size());
@@ -197,7 +197,7 @@ void GreedySearch_Cpu::SampleTopKTopP(int k, float p, float temperature) {
if (PadIfAlreadyEOS(batch_id)) {
continue;
}
- std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size);
+ std::span<float> const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size);
SoftMax(scores, temperature);
// Find the top K scores
std::vector<int> indices(scores.size());
@@ -226,13 +226,13 @@ bool GreedySearch_Cpu::PadIfAlreadyEOS(size_t batch_id) {
return false;
}

- next_tokens_[batch_id] = params_->pad_token_id;
+ next_tokens_[batch_id] = params_->config.model.pad_token_id;
return true;
}

void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token) {
next_tokens_[batch_id] = token;
- if (token == params_->eos_token_id) {
+ if (token == params_->config.model.eos_token_id) {
eos_seen_[batch_id] = true;
if (g_log.enabled && g_log.hit_eos)
Log("hit_eos", "EOS seen on batch " + std::to_string(batch_id));
@@ -295,7 +295,7 @@ RoamingArray<int32_t> BeamSearch_Cpu::GetSequence(size_t batch_id, size_t beam_i

std::span<float> Search_Cpu::GetScores(int batch_beam_index) const {
assert(batch_beam_index >= 0 && batch_beam_index < params_->BatchBeamSize());
- return next_token_scores_.subspan(static_cast<size_t>(batch_beam_index) * params_->vocab_size, params_->vocab_size);
+ return next_token_scores_.subspan(static_cast<size_t>(batch_beam_index) * params_->config.model.vocab_size, params_->config.model.vocab_size);
}

void Search_Cpu::ApplyMinLength(int min_length) {
@@ -306,7 +306,7 @@ void Search_Cpu::ApplyMinLength(int min_length) {
const int batch_beam_size = params_->BatchBeamSize();
for (int i = 0; i < batch_beam_size; i++) {
std::span<float> const beam_token_scores = GetScores(i);
- beam_token_scores[params_->eos_token_id] = std::numeric_limits<float>::lowest();
+ beam_token_scores[params_->config.model.eos_token_id] = std::numeric_limits<float>::lowest();
}
}
