diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp index ec9760377..033118285 100644 --- a/src/beam_search_scorer.cpp +++ b/src/beam_search_scorer.cpp @@ -47,8 +47,8 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) : batch_size_{parameters.batch_size}, num_beams_{parameters.search.num_beams}, max_length_{parameters.search.max_length}, - pad_token_id_{parameters.pad_token_id}, - eos_token_id_{parameters.eos_token_id}, + pad_token_id_{parameters.config.model.pad_token_id}, + eos_token_id_{parameters.config.model.eos_token_id}, early_stopping_{parameters.search.early_stopping}, not_done_count_{parameters.batch_size} { size_t const batch_beam_size = static_cast(batch_size_) * num_beams_; diff --git a/src/beam_search_scorer_cuda.cpp b/src/beam_search_scorer_cuda.cpp index 4c48ed82a..c61b69111 100644 --- a/src/beam_search_scorer_cuda.cpp +++ b/src/beam_search_scorer_cuda.cpp @@ -12,8 +12,8 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters) state_cpu_->batch_size_ = static_cast(parameters.batch_size); state_cpu_->num_beams_ = static_cast(parameters.search.num_beams); state_cpu_->max_length_ = static_cast(parameters.search.max_length); - state_cpu_->pad_token_id_ = parameters.pad_token_id; - state_cpu_->eos_token_id_ = parameters.eos_token_id; + state_cpu_->pad_token_id_ = parameters.config.model.pad_token_id; + state_cpu_->eos_token_id_ = parameters.config.model.eos_token_id; state_cpu_->early_stopping_ = parameters.search.early_stopping; state_cpu_->not_done_count_ = parameters.batch_size; state_cpu_->hypothesis_buffer_used_ = 0; diff --git a/src/generators.cpp b/src/generators.cpp index f0cb710c7..4a2a08b78 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -64,16 +64,14 @@ std::string to_string(DeviceType device_type) { throw std::runtime_error("Unknown device type"); } +GeneratorParams::GeneratorParams(const Config& config) : config{config} { +} + GeneratorParams::GeneratorParams(const Model& model) - : search{model.config_->search}, - pad_token_id{model.config_->model.pad_token_id}, - eos_token_id{model.config_->model.eos_token_id}, - vocab_size{model.config_->model.vocab_size}, - hidden_size{model.config_->model.decoder.hidden_size}, + : config{*model.config_.get()}, device_type{model.device_type_}, cuda_stream{model.cuda_stream_}, - is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)}, - config_{model.config_.get()} { + is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} { use_cuda_graph = is_cuda_graph_enabled_; if (use_cuda_graph) { max_batch_size = 1; // set it to 1 by default @@ -107,7 +105,7 @@ void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { } else { // If the nominal name is found in the map, use the graph name. // Else, use the nominal name as the graph name. 
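For orientation, here is a minimal, self-contained sketch of the ownership pattern the two GeneratorParams constructors above establish, using simplified placeholder types rather than the real headers: the params object binds a non-owning reference to a caller-supplied Config (owned by the Model in the normal path, or created standalone for benchmarks), so that Config must outlive the params object.

// Sketch only; placeholder stand-ins for Config and GeneratorParams.
struct ConfigSketch {
  struct Model { int pad_token_id{}; int eos_token_id{}; int vocab_size{}; int context_length{}; } model;
  struct Search { int max_length{}; int num_beams{1}; bool do_sample{}; } search;
};

struct ParamsSketch {
  explicit ParamsSketch(const ConfigSketch& config) : config{config} {}
  const ConfigSketch& config;                   // non-owning: the caller guarantees the Config outlives this object
  ConfigSketch::Search search{config.search};   // mutable per-generation copy of the search settings
};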
- [[maybe_unused]] const auto [graph_name, found] = config_->GetGraphName(name); + [[maybe_unused]] const auto [graph_name, found] = config.GetGraphName(name); extra_inputs.push_back({graph_name, tensor}); } } @@ -139,8 +137,8 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")"); if (params.batch_size < 1) throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size)); - if (params.vocab_size < 1) - throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size)); + if (params.config.model.vocab_size < 1) + throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.config.model.vocab_size)); if (params.sequence_length >= params.search.max_length) throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")"); if (params.input_ids.empty() || params.input_ids.data() == nullptr) diff --git a/src/generators.h b/src/generators.h index b585c0b8b..488dd8fa9 100644 --- a/src/generators.h +++ b/src/generators.h @@ -57,44 +57,21 @@ enum struct DeviceType { std::string to_string(DeviceType device_type); struct GeneratorParams : std::enable_shared_from_this, LeakChecked { - GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in + GeneratorParams(const Config& config); // This constructor is only used for internal generator benchmarks GeneratorParams(const Model& model); - Config::Search search; - - // Read only values copied from model - int pad_token_id{}; - int eos_token_id{}; - int vocab_size{}; - int context_length{}; + const Config& config; // The model outlives the GeneratorParams + Config::Search search{config.search}; // Copy of the search parameters from the config int batch_size{1}; int max_batch_size{0}; bool use_cuda_graph{}; int sequence_length{}; - int hidden_size{}; int BatchBeamSize() const { return search.num_beams * batch_size; } DeviceType device_type{DeviceType::CPU}; cudaStream_t cuda_stream{}; -#if 0 - struct Bert { - std::span input_ids; // Array of [batchsize][sequence_length] - }; - - struct Gpt { - using Gpt=Bert; - }; - - struct T5 { - std::span encoder_input_ids; // Array of [batchsize][sequence_length] - std::span decoder_input_ids; // Array of [batchsize][sequence_length] - }; - using Bart=T5; - -#endif - // TODO: Move this to a separate GPT struct std::span input_ids; // Array of [batchsize][sequence_length] @@ -123,8 +100,6 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec private: bool is_cuda_graph_enabled_{}; - const Config* config_{nullptr}; // Non owning pointer to the config. 
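Caller-side, the practical effect of the new header shape is that search settings are tuned on a per-params copy while model-level values are only ever read through the shared config reference. A usage fragment under that assumption, built from the same calls that appear in the tests later in this diff:

// Usage sketch; assumes `model` is an already-loaded Generators::Model.
auto params = Generators::CreateGeneratorParams(*model);
params->search.max_length = 10;                       // mutates only this params object's copy of the search block
params->search.do_sample = true;
const int vocab = params->config.model.vocab_size;    // read-only view into the model's config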
- // The model outlives the GeneratorParams }; struct Generator : LeakChecked { @@ -161,7 +136,7 @@ OrtEnv& GetOrtEnv(); std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path); std::shared_ptr CreateGeneratorParams(const Model& model); -std::shared_ptr CreateGeneratorParams(); // For benchmarking purposes only +std::shared_ptr CreateGeneratorParams(const Config& config); // For benchmarking purposes only std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params); std::vector> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index 10508b89f..cba33cea7 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -11,7 +11,7 @@ Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode, : model_{model}, state_{state}, shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, - state_.params_->sequence_length, state_.params_->hidden_size}, + state_.params_->sequence_length, model_.config_->model.decoder.hidden_size}, type_{mode == Embeddings::Mode::Input ? model_.session_info_->GetInputDataType(name) : model_.session_info_->GetOutputDataType(name)}, diff --git a/src/models/image_features.cpp b/src/models/image_features.cpp index b5693a10a..f9d61649d 100644 --- a/src/models/image_features.cpp +++ b/src/models/image_features.cpp @@ -9,7 +9,7 @@ namespace Generators { ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens) : model_{model}, state_{state}, - shape_{num_image_tokens, state_.params_->hidden_size}, + shape_{num_image_tokens, model.config_->model.decoder.hidden_size}, type_{mode == ImageFeatures::Mode::Input ? 
model_.session_info_->GetInputDataType(name) : model_.session_info_->GetOutputDataType(name)}, diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 2c80aad2f..0ced666b5 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -55,7 +55,7 @@ InputIDs::InputIDs(const Model& model, State& state) if (state_.params_->BatchBeamSize() != 1) { throw std::runtime_error("Batch size must be 1 for current_sequence_length and past_sequence_length inputs"); } - const int32_t current_sequence_length = get_unpadded_sequence_length(state_.params_->input_ids, state_.params_->pad_token_id); + const int32_t current_sequence_length = get_unpadded_sequence_length(state_.params_->input_ids, model_.config_->model.pad_token_id); const std::array current_sequence_length_shape{1}; const std::array past_sequence_length_shape{1, 1}; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 5d371a74c..6151402b6 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -12,7 +12,7 @@ namespace Generators { Logits::Logits(const Model& model, State& state) : model_{model}, state_{state}, - shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, state_.params_->vocab_size}, + shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, model_.config_->model.vocab_size}, type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); @@ -70,7 +70,7 @@ RoamingArray Logits::Get() { // Find the first non pad token from the end size_t token_index = seq_length; while (token_index-- > 0) { - if (input_ids[token_index] != state_.params_->pad_token_id) + if (input_ids[token_index] != model_.config_->model.pad_token_id) break; } diff --git a/src/models/model.cpp b/src/models/model.cpp index a83d7c2a6..780566937 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -508,8 +508,8 @@ std::shared_ptr CreateGeneratorParams(const Model& model) { } // Used by benchmarking tests only, should not be used normally -std::shared_ptr CreateGeneratorParams() { - return std::make_shared(); +std::shared_ptr CreateGeneratorParams(const Config& config) { + return std::make_shared(config); } void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr& p_out, DeviceType device_type, cudaStream_t stream) { diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 91671aedc..2666afc17 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -334,7 +334,7 @@ void PositionInputs::InitializeTensors(std::array shape, cpu_spanpad_token_id) { + if (*word_id == model_.config_->model.pad_token_id) { *mask = 0; *position = 0; } else { diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index fdb292258..4ce580fad 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -196,7 +196,7 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* span_sequences.emplace_back(sequences[i]); } - params.input_ids_owner = Generators::PadInputs(span_sequences, params.pad_token_id); + params.input_ids_owner = Generators::PadInputs(span_sequences, params.config.model.pad_token_id); params.batch_size = static_cast(sequences.size()); params.sequence_length = static_cast(params.input_ids_owner.size() / params.batch_size); params.input_ids = params.input_ids_owner; diff --git a/src/python/python.cpp 
b/src/python/python.cpp index f72a7c654..ffbbc5d36 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -375,9 +375,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "GeneratorParams") .def(pybind11::init()) - .def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->pad_token_id; }) - .def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->eos_token_id; }) - .def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->vocab_size; }) + // TODO(ryanhill): Remove these entirely or replace with a single property that returns the entire config? + .def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->config.model.pad_token_id; }) + .def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->config.model.eos_token_id; }) + .def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->config.model.vocab_size; }) .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_) // TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_) diff --git a/src/search.cpp b/src/search.cpp index d7a9d3c69..6dc938dc4 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -64,7 +64,7 @@ int Search_Cpu::GetSequenceLength() const { void BeamSearch_Cpu::SelectTop() { // Normalize next token scores for (int i = 0; i < params_->BatchBeamSize(); i++) { - std::span const scores = next_token_scores_.subspan(static_cast(i) * static_cast(params_->vocab_size), params_->vocab_size); + std::span const scores = next_token_scores_.subspan(static_cast(i) * static_cast(params_->config.model.vocab_size), params_->config.model.vocab_size); LogSoftMax(scores, 1.0); } @@ -77,7 +77,7 @@ void BeamSearch_Cpu::SelectTop() { int batch_beam_index = 0; for (int i = 0; i < params_->batch_size; i++) { for (int j = 0; j < params_->search.num_beams; j++, batch_beam_index++) { - for (int k = 0; k < params_->vocab_size; k++, offset++) { + for (int k = 0; k < params_->config.model.vocab_size; k++, offset++) { next_token_scores_[offset] += beam_scores[batch_beam_index]; } } @@ -103,7 +103,7 @@ void BeamSearch_Cpu::SelectTop() { // TODO(aciddelgado): Optimize this top k with partial sort for (size_t batch_index = 0; batch_index < static_cast(params_->batch_size); batch_index++) { std::priority_queue> queue; - auto token_scores_sub = next_token_scores_.subspan(batch_index * params_->search.num_beams * params_->vocab_size, static_cast(params_->search.num_beams) * params_->vocab_size); + auto token_scores_sub = next_token_scores_.subspan(batch_index * params_->search.num_beams * params_->config.model.vocab_size, static_cast(params_->search.num_beams) * params_->config.model.vocab_size); for (int i = 0; i < token_scores_sub.size(); i++) { queue.push({token_scores_sub[i], i}); } @@ -113,8 +113,8 @@ void BeamSearch_Cpu::SelectTop() { auto next_scores_sub = next_scores.subspan(top_k * batch_index, top_k); for (unsigned i = 0; i < top_k; i++) { auto v = queue.top(); - next_indices_sub[i] = v.index / params_->vocab_size; - next_tokens_sub[i] = v.index % params_->vocab_size; + next_indices_sub[i] = v.index / params_->config.model.vocab_size; + next_tokens_sub[i] = v.index % params_->config.model.vocab_size; next_scores_sub[i] = v.score; queue.pop(); } @@ -139,7 +139,7 @@ void GreedySearch_Cpu::SelectTop() { continue; } - std::span const 
scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size); + std::span const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size); auto const token = static_cast(std::distance(scores.begin(), std::max_element(scores.begin(), scores.end()))); SetNextToken(batch_id, token); } @@ -149,7 +149,7 @@ void GreedySearch_Cpu::SelectTop() { void GreedySearch_Cpu::SampleTopK(int k, float temperature) { for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) { - std::span const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size); + std::span const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size); SoftMax(scores, temperature); // Find the top K scores std::vector indices(scores.size()); @@ -168,7 +168,7 @@ void GreedySearch_Cpu::SampleTopP(float p, float temperature) { if (PadIfAlreadyEOS(batch_id)) { continue; } - std::span const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size); + std::span const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size); SoftMax(scores, temperature); // Sort an array of indices into the scores std::vector indices(scores.size()); @@ -197,7 +197,7 @@ void GreedySearch_Cpu::SampleTopKTopP(int k, float p, float temperature) { if (PadIfAlreadyEOS(batch_id)) { continue; } - std::span const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size); + std::span const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size); SoftMax(scores, temperature); // Find the top K scores std::vector indices(scores.size()); @@ -226,13 +226,13 @@ bool GreedySearch_Cpu::PadIfAlreadyEOS(size_t batch_id) { return false; } - next_tokens_[batch_id] = params_->pad_token_id; + next_tokens_[batch_id] = params_->config.model.pad_token_id; return true; } void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token) { next_tokens_[batch_id] = token; - if (token == params_->eos_token_id) { + if (token == params_->config.model.eos_token_id) { eos_seen_[batch_id] = true; if (g_log.enabled && g_log.hit_eos) Log("hit_eos", "EOS seen on batch " + std::to_string(batch_id)); @@ -295,7 +295,7 @@ RoamingArray BeamSearch_Cpu::GetSequence(size_t batch_id, size_t beam_i std::span Search_Cpu::GetScores(int batch_beam_index) const { assert(batch_beam_index >= 0 && batch_beam_index < params_->BatchBeamSize()); - return next_token_scores_.subspan(static_cast(batch_beam_index) * params_->vocab_size, params_->vocab_size); + return next_token_scores_.subspan(static_cast(batch_beam_index) * params_->config.model.vocab_size, params_->config.model.vocab_size); } void Search_Cpu::ApplyMinLength(int min_length) { @@ -306,7 +306,7 @@ void Search_Cpu::ApplyMinLength(int min_length) { const int batch_beam_size = params_->BatchBeamSize(); for (int i = 0; i < batch_beam_size; i++) { std::span const beam_token_scores = GetScores(i); - beam_token_scores[params_->eos_token_id] = std::numeric_limits::lowest(); + beam_token_scores[params_->config.model.eos_token_id] = std::numeric_limits::lowest(); } } diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp index 160d102b1..23a232bb3 100644 --- a/src/search_cuda.cpp +++ b/src/search_cuda.cpp @@ -39,7 +39,7 @@ GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params) random_seed = 
params_->search.random_seed; else random_seed = std::random_device{}(); - samplingdata_ = std::make_unique(random_seed, params_->batch_size, params_->vocab_size, params_->cuda_stream); + samplingdata_ = std::make_unique(random_seed, params_->batch_size, params_->config.model.vocab_size, params_->cuda_stream); } BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) @@ -51,7 +51,7 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) topk_next_tokens_ = CudaMallocArray(2 * batch_beam_size); topk_next_indices_ = CudaMallocArray(2 * batch_beam_size); topk_next_scores_ = CudaMallocArray(2 * batch_beam_size); - softmax_buffer_ = CudaMallocArray(batch_beam_size * params_->vocab_size); + softmax_buffer_ = CudaMallocArray(batch_beam_size * params_->config.model.vocab_size); constexpr size_t max_parts_of_vocab = 128; size_t topk_buffer_size = batch_beam_size * (max_parts_of_vocab + 1) * params_->search.num_beams * 2 * 2; @@ -84,12 +84,12 @@ int Search_Cuda::GetSequenceLength() const { } void BeamSearch_Cuda::SelectTop() { - cuda::DispatchBlockwiseSoftmaxForward(const_cast(¶ms_->cuda_stream), softmax_buffer_.get(), next_token_scores_.data(), params_->vocab_size, - params_->vocab_size, params_->vocab_size, params_->BatchBeamSize()); + cuda::DispatchBlockwiseSoftmaxForward(const_cast(¶ms_->cuda_stream), softmax_buffer_.get(), next_token_scores_.data(), params_->config.model.vocab_size, + params_->config.model.vocab_size, params_->config.model.vocab_size, params_->BatchBeamSize()); // Copy next_token_scores to CPU - auto next_token_scores_cpu = CudaMallocHostArray(params_->BatchBeamSize() * params_->vocab_size); - cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->vocab_size * sizeof(float), cudaMemcpyDeviceToHost, params_->cuda_stream); + auto next_token_scores_cpu = CudaMallocHostArray(params_->BatchBeamSize() * params_->config.model.vocab_size); + cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->config.model.vocab_size * sizeof(float), cudaMemcpyDeviceToHost, params_->cuda_stream); CudaCheck() == cudaStreamSynchronize(params_->cuda_stream); auto beam_scores = beam_scorer_->GetNextScores(); @@ -97,7 +97,7 @@ void BeamSearch_Cuda::SelectTop() { // Add beam score to next token scores. 
Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) cuda::LaunchAddProbsKernel(softmax_buffer_.get(), beam_scores.data(), - params_->batch_size, params_->search.num_beams, params_->vocab_size, params_->cuda_stream); + params_->batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->cuda_stream); if (params_->search.num_beams <= 32) { constexpr size_t max_parts_of_vocab = 128; @@ -111,7 +111,7 @@ void BeamSearch_Cuda::SelectTop() { cuda::BeamSearchTopK(softmax_buffer_.get(), params_->batch_size, params_->search.num_beams, - params_->vocab_size, + params_->config.model.vocab_size, 2 * params_->search.num_beams, topk_scores_1st_stage, topk_tokens_1st_stage, @@ -145,7 +145,7 @@ void BeamSearch_Cuda::SelectTop() { } void GreedySearch_Cuda::SelectTop() { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->vocab_size); + std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), params_->batch_size, 1, 0.0, 1.0); CheckForEOS(); @@ -153,7 +153,7 @@ void GreedySearch_Cuda::SelectTop() { } void GreedySearch_Cuda::SampleTopP(float p, float temperature) { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->vocab_size); + std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), params_->batch_size, -1, p, temperature); CheckForEOS(); @@ -161,7 +161,7 @@ void GreedySearch_Cuda::SampleTopP(float p, float temperature) { } void GreedySearch_Cuda::SampleTopK(int k, float temperature) { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->vocab_size); + std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), params_->batch_size, k, 0.0, temperature); CheckForEOS(); @@ -169,7 +169,7 @@ void GreedySearch_Cuda::SampleTopK(int k, float temperature) { } void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->vocab_size); + std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), params_->batch_size, k, p, temperature); CheckForEOS(); @@ -178,7 +178,7 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { void GreedySearch_Cuda::CheckForEOS() { assert(next_tokens_.size() == eos_meet_.size()); - cuda::Launch_CheckForEOS(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->eos_token_id, params_->pad_token_id, done_cpu_.get(), params_->cuda_stream); + cuda::Launch_CheckForEOS(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->config.model.pad_token_id, done_cpu_.get(), params_->cuda_stream); } void GreedySearch_Cuda::AppendNextTokensToSequences() { @@ -246,7 +246,7 @@ void 
GreedySearch::Finalize(size_t num_return_sequences, std::span outp std::span Search_Cuda::GetScores(int batch_beam_index) { assert(batch_beam_index >= 0 && batch_beam_index < params_->BatchBeamSize()); - return next_token_scores_.subspan(batch_beam_index * params_->vocab_size, params_->vocab_size); + return next_token_scores_.subspan(batch_beam_index * params_->config.model.vocab_size, params_->config.model.vocab_size); } std::span Search_Cuda::GetScores() { @@ -257,7 +257,7 @@ void Search_Cuda::ApplyMinLength(int min_length) { if (sequences_.GetSequenceLength() >= min_length) return; - cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), params_->vocab_size, params_->eos_token_id, std::numeric_limits::lowest(), params_->cuda_stream); + cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), params_->config.model.vocab_size, params_->config.model.eos_token_id, std::numeric_limits::lowest(), params_->cuda_stream); } void Search_Cuda::ApplyRepetitionPenalty(float penalty) { @@ -265,7 +265,7 @@ void Search_Cuda::ApplyRepetitionPenalty(float penalty) { return; cuda::LaunchRepetitionPenaltyProcessor(sequences_.GetSequences().data(), - GetScores().data(), params_->batch_size, params_->search.num_beams, params_->vocab_size, + GetScores().data(), params_->batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->search.max_length, GetSequenceLength(), penalty, params_->cuda_stream); } diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index e614b2b20..59cfa61c0 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -19,17 +19,19 @@ void CreateRandomLogits(float* logits, int num_large, int vocab_size, int batch_ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCpu) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 1; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; - std::vector logits_cpu(vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ -38,8 +40,8 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCpu) { for (int i = 0; i < num_iter; i++) { auto generator = Generators::CreateGenerator(*model, *params); int num_large = dist(engine); - CreateRandomLogits(logits_cpu.data(), num_large, vocab_size, batch_size, engine); - generator->search_->SetLogits(Generators::cpu_span(logits_cpu.data(), vocab_size * batch_size)); + CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine); + generator->search_->SetLogits(Generators::cpu_span(logits_cpu.data(), config.model.vocab_size * batch_size)); auto start = std::chrono::high_resolution_clock::now(); generator->search_->SampleTopP(0.95f, 1.0f); auto stop = std::chrono::high_resolution_clock::now(); @@ -53,18 +55,20 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCpu) { TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCpu) { auto model = 
Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 1; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; - std::vector logits_cpu(vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(5, 25); @@ -73,8 +77,8 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCpu) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - CreateRandomLogits(logits_cpu.data(), num_large, vocab_size, batch_size, engine); - generator->search_->SetLogits(Generators::cpu_span(logits_cpu.data(), vocab_size * batch_size)); + CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine); + generator->search_->SetLogits(Generators::cpu_span(logits_cpu.data(), config.model.vocab_size * batch_size)); auto start = std::chrono::high_resolution_clock::now(); generator->search_->SampleTopK(k, 1.0f); @@ -90,19 +94,21 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCpu) { TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCpu) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 1; float p = 0.95f; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; - std::vector logits_cpu(vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(5, 25); @@ -111,8 +117,8 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCpu) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - CreateRandomLogits(logits_cpu.data(), num_large, vocab_size, batch_size, engine); - generator->search_->SetLogits(Generators::cpu_span(logits_cpu.data(), vocab_size * batch_size)); + CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine); + generator->search_->SetLogits(Generators::cpu_span(logits_cpu.data(), config.model.vocab_size * batch_size)); auto start = std::chrono::high_resolution_clock::now(); generator->search_->SampleTopKTopP(k, p, 1.0f); @@ -131,31 +137,33 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCpu) { TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size 
of llama int batch_size = 1; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - std::vector cpu_logits(vocab_size * batch_size); + std::vector cpu_logits(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); double total_time = 0.0; int num_iter = 1000; for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpy(cpu_logits.data(), logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpy(cpu_logits.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); cudaStreamSynchronize(params->cuda_stream); auto start = std::chrono::high_resolution_clock::now(); @@ -174,20 +182,22 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCuda) { TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 1; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); - std::vector cpu_logits(vocab_size * batch_size); + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + std::vector cpu_logits(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ 
-196,10 +206,10 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpy(cpu_logits.data(), logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpy(cpu_logits.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); cudaStreamSynchronize(params->cuda_stream); auto start = std::chrono::high_resolution_clock::now(); generator->search_->SampleTopK(k, 1.0f); @@ -214,21 +224,23 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) { TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 1; float p = 0.95f; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); - std::vector cpu_logits(vocab_size * batch_size); + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + std::vector cpu_logits(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ -237,10 +249,10 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) { for (int i = 0; i < num_iter; i++) { auto generator = Generators::CreateGenerator(*model, *params); int num_large = dist(engine); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpy(cpu_logits.data(), logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpy(cpu_logits.data(), 
logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); cudaStreamSynchronize(params->cuda_stream); auto start = std::chrono::high_resolution_clock::now(); @@ -259,19 +271,21 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) { TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 12; std::vector input_ids{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; // Needs to match batch_size - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); - std::vector cpu_logits(vocab_size * batch_size); + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + std::vector cpu_logits(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ -280,10 +294,10 @@ TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpy(cpu_logits.data(), logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpy(cpu_logits.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); cudaStreamSynchronize(params->cuda_stream); auto start = std::chrono::high_resolution_clock::now(); generator->search_->SelectTop(); diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 85c151d2e..7a6505255 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -22,15 +22,15 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { 0.1f, 0.1f, 0.6f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.6f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.6f}; - int vocab_size = 5; - int batch_size = 4; - auto params = Generators::CreateGeneratorParams(); + Generators::Config config; + config.model.vocab_size = 5; + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; 
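CreateRandomLogits, the helper shared by the benchmarks above and the randomized tests below, appears here only in hunk context; its body is not part of the diff. A plausible stand-in, assuming it fills each batch row with small background values and scatters num_large clearly dominant scores, could look like this (hypothetical, for illustration only):

#include <cstdint>
#include <random>

void CreateRandomLogitsSketch(float* logits, int num_large, int vocab_size, int batch_size, std::mt19937& engine) {
  std::uniform_real_distribution<float> background(0.0f, 1.0f);
  std::uniform_int_distribution<int> pick(0, vocab_size - 1);
  for (int b = 0; b < batch_size; b++) {
    float* row = logits + static_cast<int64_t>(b) * vocab_size;
    for (int v = 0; v < vocab_size; v++)
      row[v] = background(engine);                     // low scores everywhere
    for (int n = 0; n < num_large; n++)
      row[pick(engine)] = 25.0f + background(engine);  // a handful of dominant tokens
  }
}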
params->search.top_p = 0.25f; - params->batch_size = batch_size; + params->batch_size = 4; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto generator = Generators::CreateGenerator(*model, *params); @@ -50,15 +50,16 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { 0.25f, 2.0f, 1.25f, 1.5f, 0.25f, 0.25f, 2.0f, 0.25f, 1.5f, 1.25f, 1.25f, 0.25f, 1.5f, 0.25f, 2.0f}; - int vocab_size = 5; + Generators::Config config; + config.model.vocab_size = 5; + int batch_size = 4; - auto params = Generators::CreateGeneratorParams(); + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = 2; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); @@ -71,7 +72,7 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 1.25f); } } @@ -83,16 +84,18 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { 0.25f, 2.0f, 1.25f, 1.5f, 0.25f, 0.25f, 2.0f, 0.25f, 1.5f, 1.25f, 1.25f, 0.25f, 1.5f, 0.25f, 2.0f}; - int vocab_size = 5; + + Generators::Config config; + config.model.vocab_size = 5; + int batch_size = 4; - auto params = Generators::CreateGeneratorParams(); + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = 2; params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); @@ -104,7 +107,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 1.25f); } } @@ -133,19 +136,21 @@ void CreateRandomLogits(float* logits, int num_large, int vocab_size, int batch_ TEST(SamplingTests, RandomizedSamplingTopPCpu) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.95f; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; - std::vector logits_cpu(vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; 
std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ -153,7 +158,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { for (int i = 0; i < num_iter; i++) { auto generator = Generators::CreateGenerator(*model, *params); int num_large = dist(engine); - CreateRandomLogits(logits_cpu.data(), num_large, vocab_size, batch_size, engine); + CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine); auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; @@ -162,7 +167,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 1.0f); } } @@ -170,20 +175,22 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { TEST(SamplingTests, RandomizedSamplingTopKCpu) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = k; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; - std::vector logits_cpu(vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(5, 25); @@ -191,7 +198,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - CreateRandomLogits(logits_cpu.data(), num_large, vocab_size, batch_size, engine); + CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine); auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; @@ -200,7 +207,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 10.0f); } } @@ -208,22 +215,24 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; float p = 0.95f; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = k; 
params->search.top_p = p; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; - std::vector logits_cpu(vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(5, 25); @@ -231,7 +240,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - CreateRandomLogits(logits_cpu.data(), num_large, vocab_size, batch_size, engine); + CreateRandomLogits(logits_cpu.data(), num_large, config.model.vocab_size, batch_size, engine); auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; @@ -240,7 +249,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 10.0f); } } @@ -259,15 +268,17 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { 0.1f, 0.1f, 0.1f, 0.6f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.6f}; auto logits_gpu = Generators::CudaMallocArray(logits_cpu.size()); - int vocab_size = 5; int batch_size = 4; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 5; + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); @@ -289,15 +300,17 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { 0.25f, 2.0f, 0.25f, 1.5f, 1.25f, 1.25f, 0.25f, 1.5f, 0.25f, 2.0f}; auto logits_gpu = Generators::CudaMallocArray(logits_cpu.size()); - int vocab_size = 5; int batch_size = 4; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 5; + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = 2; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); @@ -310,7 +323,7 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 1.25f); } } @@ -323,16 +336,18 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { 0.25f, 2.0f, 0.25f, 1.5f, 1.25f, 1.25f, 0.25f, 1.5f, 0.25f, 2.0f}; auto logits_gpu = Generators::CudaMallocArray(logits_cpu.size()); - 
int vocab_size = 5; int batch_size = 4; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 5; + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = 2; params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); @@ -345,28 +360,31 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = logits_cpu[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 1.25f); } } TEST(SamplingTests, RandomizedSamplingTopPCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.95f; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); - float* cpu_logits = new float[vocab_size * batch_size]; + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); + std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ -374,10 +392,10 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpyAsync(logits_cpu.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); 
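The verification in these loops relies on logits being one contiguous row-major [batch_size, vocab_size] buffer, so the sampled token's score for batch entry b is recovered by flat indexing — the same convention behind the next_token_scores_.subspan(i * vocab_size, vocab_size) slices in search.cpp. A tiny sketch of that convention (hypothetical helper name):

#include <cstddef>
#include <vector>

// Row b of a row-major [batch_size, vocab_size] buffer starts at b * vocab_size.
float ScoreOf(const std::vector<float>& logits, int vocab_size, int b, int token) {
  return logits[static_cast<std::size_t>(b) * vocab_size + token];
}
// e.g. auto next_token_score = ScoreOf(logits_cpu, config.model.vocab_size, b, next_tokens[b]);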
generator->computed_logits_ = true; generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); @@ -385,7 +403,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = cpu_logits[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 0.0001f); } } @@ -393,33 +411,35 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { TEST(SamplingTests, RandomizedSamplingTopKCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = k; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); - float* cpu_logits = new float[vocab_size * batch_size]; + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); int num_iter = 100; for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpyAsync(logits_cpu.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); auto generator = Generators::CreateGenerator(*model, *params); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); generator->computed_logits_ = true; generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); @@ -427,7 +447,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = cpu_logits[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 10.0f); } } @@ -435,24 +455,26 @@ 
TEST(SamplingTests, RandomizedSamplingTopKCuda) { TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; float p = 0.95f; int k = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = k; params->search.top_p = p; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); - float* cpu_logits = new float[vocab_size * batch_size]; + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); @@ -460,10 +482,10 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); auto generator = Generators::CreateGenerator(*model, *params); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpyAsync(logits_cpu.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); generator->computed_logits_ = true; generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); @@ -471,7 +493,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; - auto next_token_score = cpu_logits[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_GT(next_token_score, 10.0f); } } @@ -479,39 +501,41 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { TEST(SamplingTests, RandomizedSamplingSelectTopCuda) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - int vocab_size = 32000; // vocab size of llama int batch_size = 5; std::vector input_ids{0, 1, 2, 3, 4}; - auto params = Generators::CreateGeneratorParams(); + + Generators::Config config; + 
config.model.vocab_size = 32000; // vocab size of llama + + auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; params->batch_size = batch_size; params->sequence_length = 1; - params->vocab_size = vocab_size; params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; - auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); - auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); - float* cpu_logits = new float[vocab_size * batch_size]; + auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); + auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); + std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; std::mt19937 engine(rd()); std::uniform_int_distribution<> dist(1, 25); int num_iter = 100; for (int i = 0; i < num_iter; i++) { int num_large = dist(engine); - LaunchGeometricDecayKernel(logits_gpu.get(), vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); - LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), vocab_size, batch_size, params->cuda_stream); - cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); + LaunchGeometricDecayKernel(logits_gpu.get(), config.model.vocab_size, batch_size, num_large, 20.0f, params->cuda_stream); + LaunchFisherYatesKernel(logits_gpu.get(), indices_buffer.get(), config.model.vocab_size, batch_size, params->cuda_stream); + cudaMemcpyAsync(logits_cpu.data(), logits_gpu.get(), config.model.vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); auto generator = Generators::CreateGenerator(*model, *params); - generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); + generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), config.model.vocab_size * batch_size)); generator->computed_logits_ = true; generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { - float max_score = *std::max_element(cpu_logits + vocab_size * b, cpu_logits + vocab_size * (b + 1)); + float max_score = *std::max_element(logits_cpu.begin() + config.model.vocab_size * b, logits_cpu.begin() + config.model.vocab_size * (b + 1)); auto next_token = next_tokens[b]; - auto next_token_score = cpu_logits[next_token + vocab_size * b]; + auto next_token_score = logits_cpu[next_token + config.model.vocab_size * b]; EXPECT_EQ(next_token_score, max_score); } }
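For reference, a minimal sketch (not part of the patch) of the setup pattern these test hunks converge on: the test builds a local Generators::Config, sets config.model.vocab_size, and hands the config to CreateGeneratorParams rather than writing vocab_size onto the params afterwards. The test name and the one-token input below are illustrative only; every other identifier is taken from the hunks above.

// Hypothetical example test; mirrors the config-driven setup used in the CUDA sampling tests.
TEST(SamplingTests, ExampleConfigDrivenSetup) {
  auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");

  // The config, not the params, now owns vocab_size (and the other model constants).
  Generators::Config config;
  config.model.vocab_size = 32000; // vocab size of llama

  auto params = Generators::CreateGeneratorParams(config);
  params->search.max_length = 10;
  params->batch_size = 1;
  params->sequence_length = 1;
  std::vector<int32_t> input_ids{0};  // must outlive params, which only holds a span
  params->input_ids = input_ids;
  params->device_type = Generators::DeviceType::CUDA;

  auto generator = Generators::CreateGenerator(*model, *params);
  // Feed logits via generator->search_->SetLogits(...) and call GenerateNextToken() as above;
  // scores are then indexed with config.model.vocab_size, e.g.
  //   logits_cpu[next_token + config.model.vocab_size * b]
}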