diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp
index e924a26f8..b9b35728f 100644
--- a/benchmark/c/main.cpp
+++ b/benchmark/c/main.cpp
@@ -121,11 +121,15 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons
   auto params = OgaGeneratorParams::Create(model);
   params->SetSearchOption("max_length", static_cast(num_prompt_tokens));
   params->SetSearchOption("min_length", static_cast(num_prompt_tokens));
-  params->SetInputSequences(*base_prompt_sequences);
-  auto output_sequences = model.Generate(*params);
-  const auto output_sequence_length = output_sequences->SequenceCount(0);
-  const auto* output_sequence_data = output_sequences->SequenceData(0);
+  auto generator = OgaGenerator::Create(model, *params);
+  generator->AddInputSequences(*base_prompt_sequences);
+  while (!generator->IsDone()) {
+    generator->GenerateNextToken();
+  }
+
+  const auto output_sequence_length = generator->GetSequenceCount(0);
+  const auto* output_sequence_data = generator->GetSequenceData(0);
   return std::string{tokenizer.Decode(output_sequence_data, output_sequence_length)};
 }
@@ -151,7 +155,6 @@ void RunBenchmark(const benchmark::Options& opts) {
     auto params = OgaGeneratorParams::Create(*model);
     params->SetSearchOption("max_length", static_cast(num_tokens));
     params->SetSearchOption("min_length", static_cast(num_tokens));
-    params->SetInputSequences(*prompt_sequences);
     return params;
   };
@@ -160,13 +163,17 @@ void RunBenchmark(const benchmark::Options& opts) {
   // warmup
   if (opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n";
   for (size_t i = 0; i < opts.num_warmup_iterations; ++i) {
-    auto output_sequences = model->Generate(*generator_params);
+    auto generator = OgaGenerator::Create(*model, *generator_params);
+    generator->AddInputSequences(*prompt_sequences);
+    while (!generator->IsDone()) {
+      generator->GenerateNextToken();
+    }
     if (opts.verbose && i == 0) {
       // show prompt and output on first iteration
       std::cout << "Prompt:\n\t" << prompt << "\n";
-      const auto output_sequence_length = output_sequences->SequenceCount(0);
-      const auto* output_sequence_data = output_sequences->SequenceData(0);
+      const auto output_sequence_length = generator->GetSequenceCount(0);
+      const auto* output_sequence_data = generator->GetSequenceData(0);
       const auto output = tokenizer->Decode(output_sequence_data, output_sequence_length);
       std::cout << "Output:\n\t" << output << "\n";
     }
@@ -188,7 +195,7 @@ void RunBenchmark(const benchmark::Options& opts) {
     {
       Timing prompt_processing_timing{prompt_processing_times};
-      generator->ComputeLogits();
+      generator->AddInputSequences(*prompt_sequences);
     }

     {
@@ -199,11 +206,6 @@ void RunBenchmark(const benchmark::Options& opts) {
     while (!generator->IsDone()) {
       {
         Timing token_gen_timing{token_gen_times};
-        generator->ComputeLogits();
-      }
-
-      {
-        Timing sampling_timing{sampling_times};
         generator->GenerateNextToken();
       }
     }
diff --git a/examples/c/src/phi3.cpp b/examples/c/src/phi3.cpp
index ea1562593..41845fcdc 100644
--- a/examples/c/src/phi3.cpp
+++ b/examples/c/src/phi3.cpp
@@ -93,12 +93,11 @@ void CXX_API(const char* model_path) {
   std::cout << "Generating response..." << std::endl;
   auto params = OgaGeneratorParams::Create(*model);
   params->SetSearchOption("max_length", 1024);
-  params->SetInputSequences(*sequences);
   auto generator = OgaGenerator::Create(*model, *params);
+  generator->AddInputSequences(*sequences);
   while (!generator->IsDone()) {
-    generator->ComputeLogits();
     generator->GenerateNextToken();
     if (is_first_token) {
@@ -179,13 +178,12 @@ void C_API(const char* model_path) {
   OgaGeneratorParams* params;
   CheckResult(OgaCreateGeneratorParams(model, &params));
   CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 1024));
-  CheckResult(OgaGeneratorParamsSetInputSequences(params, sequences));
   OgaGenerator* generator;
   CheckResult(OgaCreateGenerator(model, params, &generator));
+  CheckResult(OgaGenerator_AddInputSequences(generator, sequences));
   while (!OgaGenerator_IsDone(generator)) {
-    CheckResult(OgaGenerator_ComputeLogits(generator));
     CheckResult(OgaGenerator_GenerateNextToken(generator));
     if (is_first_token) {
diff --git a/examples/c/src/phi3v.cpp b/examples/c/src/phi3v.cpp
index 9d78be599..26026fc46 100644
--- a/examples/c/src/phi3v.cpp
+++ b/examples/c/src/phi3v.cpp
@@ -76,7 +76,6 @@ void CXX_API(const char* model_path) {
   auto generator = OgaGenerator::Create(*model, *params);
   while (!generator->IsDone()) {
-    generator->ComputeLogits();
     generator->GenerateNextToken();
     const auto num_tokens = generator->GetSequenceCount(0);
@@ -162,7 +161,6 @@ void C_API(const char* model_path) {
   CheckResult(OgaCreateGenerator(model, params, &generator));
   while (!OgaGenerator_IsDone(generator)) {
-    CheckResult(OgaGenerator_ComputeLogits(generator));
     CheckResult(OgaGenerator_GenerateNextToken(generator));
     const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0);
diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py
index 0a97f25b4..3e7cd2769 100644
--- a/examples/python/model-generate.py
+++ b/examples/python/model-generate.py
@@ -12,9 +12,9 @@ def main(args):
     if hasattr(args, 'prompts'):
         prompts = args.prompts
     else:
-        prompts = ["I like walking my cute dog",
-                   "What is the best restaurant in town?",
-                   "Hello, how are you today?"]
+        prompts = ["The first 4 digits of pi are",
+                   "The square root of 2 is",
+                   "The first 6 numbers of the Fibonacci sequence are",]
     if args.chat_template:
         if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
@@ -28,6 +28,7 @@ def main(args):
     params = og.GeneratorParams(model)
     search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
+    search_options['batch_size'] = 3
     if (args.verbose): print(f'Args: {args}')
     if (args.verbose): print(f'Search options: {search_options}')
@@ -37,22 +38,28 @@ def main(args):
         params.try_graph_capture_with_max_batch_size(len(prompts))
     if args.batch_size_for_cuda_graph:
         params.try_graph_capture_with_max_batch_size(args.batch_size_for_cuda_graph)
-    params.input_ids = input_tokens
     if args.verbose: print("GeneratorParams created")
+    generator = og.Generator(model, params)
+    if args.verbose: print("Generator created")
+
+    generator.add_input_tokens(input_tokens)
+    if args.verbose: print("Input tokens added")
+
     if args.verbose: print("Generating tokens ...\n")
     start_time = time.time()
-    output_tokens = model.generate(params)
+    while not generator.is_done():
+        generator.generate_next_token()
     run_time = time.time() - start_time
     for i in range(len(prompts)):
         print(f'Prompt #{i}: {prompts[i]}')
         print()
-        print(tokenizer.decode(output_tokens[i]))
+        print(tokenizer.decode(generator.get_sequence(i)))
         print()
     print()
-    total_tokens = sum(len(x) for x in output_tokens)
+    total_tokens = sum(len(generator.get_sequence(i)) for i in range(len(prompts)))
     print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}")
     print()
diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py
index 4532f307a..23ce32391 100644
--- a/examples/python/model-qa.py
+++ b/examples/python/model-qa.py
@@ -16,6 +16,7 @@ def main(args):
     if args.verbose: print()
     search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
+    search_options['batch_size'] = 1
     if args.verbose: print(search_options)
@@ -24,6 +25,16 @@ def main(args):
         print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
         exit(1)

+    params = og.GeneratorParams(model)
+    params.set_search_options(**search_options)
+    generator = og.Generator(model, params)
+
+    # Set system prompt
+    system_prompt = args.system_prompt
+    system_tokens = tokenizer.encode(system_prompt)
+    generator.add_input_tokens(system_tokens)
+    system_prompt_length = len(system_tokens)
+
     # Keep asking for input prompts in a loop
     while True:
         text = input("Input: ")
@@ -39,11 +50,8 @@ def main(args):
             prompt = f'{args.chat_template.format(input=text)}'
         input_tokens = tokenizer.encode(prompt)
-
-        params = og.GeneratorParams(model)
-        params.set_search_options(**search_options)
-        params.input_ids = input_tokens
-        generator = og.Generator(model, params)
+
+        generator.add_input_tokens(input_tokens)
         if args.verbose: print("Generator created")
         if args.verbose: print("Running generation loop ...")
@@ -56,7 +64,6 @@ def main(args):
         try:
             while not generator.is_done():
-                generator.compute_logits()
                 generator.generate_next_token()
                 if args.timings:
                     if first:
@@ -71,14 +78,14 @@ def main(args):
         print()
         print()
-        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
-        del generator
-
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp
             run_time = time.time() - first_token_timestamp
             print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
-
+
+        # Rewind the generator to the system prompt
+        if args.rewind:
+            generator.rewind_to_length(system_prompt_length)
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
@@ -93,5 +100,7 @@ def main(args):
     parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
     parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
     parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}')
+    parser.add_argument('-s', '--system_prompt', type=str, default='You are a helpful assistant. You are friendly, courteous, and professional. 
All your responses must end with an exclamation point!', help='System prompt to use for the prompt.') + parser.add_argument('-re', '--rewind', action='store_true', default=False, help='Rewind to the system prompt after each generation. Defaults to false') args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp index 033118285..a0aadea3d 100644 --- a/src/beam_search_scorer.cpp +++ b/src/beam_search_scorer.cpp @@ -44,13 +44,13 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con } BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) - : batch_size_{parameters.batch_size}, + : batch_size_{parameters.search.batch_size}, num_beams_{parameters.search.num_beams}, max_length_{parameters.search.max_length}, pad_token_id_{parameters.config.model.pad_token_id}, eos_token_id_{parameters.config.model.eos_token_id}, early_stopping_{parameters.search.early_stopping}, - not_done_count_{parameters.batch_size} { + not_done_count_{parameters.search.batch_size} { size_t const batch_beam_size = static_cast(batch_size_) * num_beams_; std::span beams; @@ -65,7 +65,10 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) next_beam_indices_ptr_ = AllocateArray(batch_beam_size, &next_beam_indices_); // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. - size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + // TODO(aciddelgado): Initialize in first update function type thing. + // size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2; + hypothesis_buffer_ptr_ = AllocateArray(batch_beam_size * per_beam, &hypothesis_buffer_); memset(next_beam_scores_.data(), 0, next_beam_scores_.size_bytes()); @@ -73,7 +76,7 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) // Initialize score of first beam of each group with 0 and the rest with -1e9. // This ensures that the beams in the same group don't produce same tokens every time. 
std::span const beam_scores = next_beam_scores_; - for (int i = 0; i < parameters.batch_size; i++) { + for (int i = 0; i < parameters.search.batch_size; i++) { for (int j = 1; j < parameters.search.num_beams; j++) { beam_scores[i * parameters.search.num_beams + j] = -1e9; } diff --git a/src/beam_search_scorer_cuda.cpp b/src/beam_search_scorer_cuda.cpp index c61b69111..efbfbcbef 100644 --- a/src/beam_search_scorer_cuda.cpp +++ b/src/beam_search_scorer_cuda.cpp @@ -9,13 +9,13 @@ namespace Generators { BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters) : stream_{parameters.cuda_stream} { state_cpu_ = CudaMallocHostArray(1); - state_cpu_->batch_size_ = static_cast(parameters.batch_size); + state_cpu_->batch_size_ = static_cast(parameters.search.batch_size); state_cpu_->num_beams_ = static_cast(parameters.search.num_beams); state_cpu_->max_length_ = static_cast(parameters.search.max_length); state_cpu_->pad_token_id_ = parameters.config.model.pad_token_id; state_cpu_->eos_token_id_ = parameters.config.model.eos_token_id; state_cpu_->early_stopping_ = parameters.search.early_stopping; - state_cpu_->not_done_count_ = parameters.batch_size; + state_cpu_->not_done_count_ = parameters.search.batch_size; state_cpu_->hypothesis_buffer_used_ = 0; state_gpu_ = CudaMallocArray(1); cudaMemcpyAsync(state_gpu_.get(), state_cpu_.get(), sizeof(cuda::BeamScorerState), ::cudaMemcpyHostToDevice, stream_); @@ -34,10 +34,12 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters) next_beam_indices_cpu_ptr_ = std::make_unique(batch_beam_size); next_beam_indices_cpu_ = cpu_span(next_beam_indices_cpu_ptr_.get(), batch_beam_size); - cuda::LaunchInitScoresKernel(next_beam_scores_.data(), parameters.batch_size, parameters.search.num_beams, stream_); + cuda::LaunchInitScoresKernel(next_beam_scores_.data(), parameters.search.batch_size, parameters.search.num_beams, stream_); // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. - size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + // TODO(aciddelgado): Initialize in first update function type thing. + // size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1)) / 2; hypothesis_buffer_ptr_ = CudaMallocArray(batch_beam_size * per_beam, &hypothesis_buffer_); } diff --git a/src/config.cpp b/src/config.cpp index 1e07e7f9f..f5cbb1434 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -537,6 +537,8 @@ struct Search_Element : JSON::Element { v_.min_length = static_cast(value); } else if (name == "max_length") { v_.max_length = static_cast(value); + } else if (name == "batch_size") { + v_.batch_size = static_cast(value); } else if (name == "num_beams") { v_.num_beams = static_cast(value); } else if (name == "num_return_sequences") { diff --git a/src/config.h b/src/config.h index 4a14cdd6e..1c10507c4 100644 --- a/src/config.h +++ b/src/config.h @@ -144,6 +144,7 @@ struct Config { bool do_sample{}; // True to do randomized sampling through top_k and top_p, if false, the top logit score is chosen int min_length{}; int max_length{}; // If omitted or 0 in json file, will be set to model.context_length on load + int batch_size{1}; int num_beams{1}; // 1 means no beam search. 
int num_return_sequences{1}; float repetition_penalty{1.0f}; // 1.0 means no penalty. diff --git a/src/generators.cpp b/src/generators.cpp index 36f12a2bd..9fef9ffc1 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -4,6 +4,7 @@ #include "generators.h" #include "sequences.h" #include "models/model.h" +#include "models/decoder_only.h" #include "search.h" #if USE_CUDA #include "search_cuda.h" @@ -100,14 +101,10 @@ void GeneratorParams::TryGraphCapture(int max_bs) { } } +// TODO(aciddelgado): Does this work? void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { for (const auto& [name, tensor] : named_tensors) { - if (name == Config::Defaults::InputIdsName) { - input_ids = std::span(tensor->ort_tensor_->GetTensorMutableData(), - tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetElementCount()); - batch_size = static_cast(tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0]); - sequence_length = static_cast(input_ids.size()) / batch_size; - } else { + if (name != Config::Defaults::InputIdsName) { // If the nominal name is found in the map, use the graph name. // Else, use the nominal name as the graph name. [[maybe_unused]] const auto [graph_name, found] = config.GetGraphName(name); @@ -140,24 +137,28 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ throw std::runtime_error("search max_length is 0"); if (params.search.max_length > model.config_->model.context_length) throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")"); - if (params.batch_size < 1) - throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size)); + if (params.search.batch_size < 1) + throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.search.batch_size)); if (params.config.model.vocab_size < 1) throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.config.model.vocab_size)); - if (params.sequence_length >= params.search.max_length) - throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")"); - if (params.input_ids.empty() || params.input_ids.data() == nullptr) - throw std::runtime_error("input_ids not set in GeneratorParams"); search_ = CreateSearch(params); - state_ = model.CreateState(search_->GetSequenceLengths(), params); + state_ = model.CreateState(search_->GetSequenceLengths(), params); // Search sequence lengths set when creating state } -void Generator::ComputeLogits() { +void Generator::AddTokens(cpu_span input_ids) { + // TODO(aciddelgado): check for batch_size > 1 requires full rewind + search_->SetUserTokens(input_ids); + + computed_logits_ = false; + ComputeLogits(input_ids); +} + +void Generator::ComputeLogits(const RoamingArray& next_tokens) { if (computed_logits_) - throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); + throw std::runtime_error("ComputeLogits called again without calling AddTokens or GenerateNextToken first"); - auto logits = state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices()); + auto logits = state_->Run(search_->GetSequenceLength(), next_tokens, search_->GetNextIndices()); if (g_log.enabled && g_log.model_logits) { auto& stream = Log("model_logits"); DumpSpan(stream, 
logits.GetCPU()); @@ -165,15 +166,13 @@ void Generator::ComputeLogits() { } search_->SetLogits(logits); computed_logits_ = true; - - auto& search = search_->params_->search; - search_->ApplyMinLength(search.min_length); - search_->ApplyRepetitionPenalty(search.repetition_penalty); } bool Generator::IsDone() const { - if (computed_logits_) - throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); + // TODO(aciddelgado): Is this the correct approach to handling computed_logits_ now? + if (computed_logits_) { + return false; + } bool is_done = search_->IsDone(); if (is_done) { @@ -184,10 +183,14 @@ bool Generator::IsDone() const { } void Generator::GenerateNextToken() { - if (!computed_logits_) - throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); + // TODO(aciddelgado): check that AddTokens has been called at least once + if (!computed_logits_) { + ComputeLogits(search_->GetNextTokens()); + } computed_logits_ = false; auto& search = search_->params_->search; + search_->ApplyMinLength(search.min_length); + search_->ApplyRepetitionPenalty(search.repetition_penalty); if (g_log.enabled && g_log.generate_next_token) { auto& stream = Log("generate_next_token"); @@ -224,27 +227,21 @@ void Generator::GenerateNextToken() { } } -RoamingArray Generator::GetSequence(size_t index) const { - return search_->GetSequence(index); +void Generator::RewindToLength(size_t new_length) { + if (new_length > search_->GetSequenceLength()) + throw std::runtime_error("Cannot rewind to a length greater than the current sequence length"); + if (new_length == search_->GetSequenceLength()) + return; + size_t batch_size = search_->params_->search.batch_size; + if (batch_size > 1 && new_length != 0) + throw std::runtime_error("RewindToLength must be called with new_length=0 when batch_size > 1"); + search_->RewindTo(new_length); + state_->RewindTo(new_length); + computed_logits_ = false; } -TokenSequences Generate(const Model& model, const GeneratorParams& params) { - auto generator = CreateGenerator(model, params); - - while (!generator->IsDone()) { - generator->ComputeLogits(); - generator->GenerateNextToken(); - } - - TokenSequences result; - for (int i = 0; i < params.batch_size * params.search.num_return_sequences; i++) { - auto sequence = generator->search_->GetSequence(i); - auto sequence_cpu = sequence.GetCPU(); - - auto& v = result.emplace_back(); - v.assign(sequence_cpu.begin(), sequence_cpu.end()); - } - return result; +RoamingArray Generator::GetSequence(size_t index) const { + return search_->GetSequence(index); } } // namespace Generators diff --git a/src/generators.h b/src/generators.h index 488dd8fa9..9905cd5bf 100644 --- a/src/generators.h +++ b/src/generators.h @@ -63,18 +63,13 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec const Config& config; // The model outlives the GeneratorParams Config::Search search{config.search}; // Copy of the search parameters from the config - int batch_size{1}; int max_batch_size{0}; bool use_cuda_graph{}; - int sequence_length{}; - int BatchBeamSize() const { return search.num_beams * batch_size; } + int BatchBeamSize() const { return search.num_beams * search.batch_size; } DeviceType device_type{DeviceType::CPU}; cudaStream_t cuda_stream{}; - // TODO: Move this to a separate GPT struct - std::span input_ids; // Array of [batchsize][sequence_length] - struct Whisper { std::shared_ptr input_features; // float32 [batch_size, number_of_mels, number_of_frames] std::shared_ptr alignment_heads; // int32 
[num_alignment_heads, 2] @@ -82,8 +77,6 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec std::variant inputs; - std::vector input_ids_owner; // Backing memory of input_ids in some cases - std::shared_ptr external_owner_; // Set to 'this' when created by the C API to preserve lifetime struct Input { @@ -106,8 +99,9 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; - void ComputeLogits(); - void GenerateNextToken(); + virtual void AddTokens(cpu_span input_ids); + virtual void GenerateNextToken(); + virtual void RewindToLength(size_t new_length); // Rewind state to new_length RoamingArray GetSequence(size_t index) const; @@ -115,6 +109,9 @@ struct Generator : LeakChecked { std::unique_ptr state_; std::unique_ptr search_; bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio + + private: + void ComputeLogits(const RoamingArray& next_tokens); }; struct OrtGlobals { @@ -138,7 +135,6 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path); std::shared_ptr CreateGeneratorParams(const Model& model); std::shared_ptr CreateGeneratorParams(const Config& config); // For benchmarking purposes only std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params); -std::vector> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction void top_k_indices(std::span top_k, std::span inputs); diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index 206549be3..c82f81eb1 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -9,8 +9,8 @@ DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort InitDeviceAllocator(*session_decoder_); } -std::unique_ptr DecoderOnly_Model::CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const { - return std::make_unique(*this, sequence_lengths, params); +std::unique_ptr DecoderOnly_Model::CreateState(RoamingArray sequence_lengths_unk, const GeneratorParams& params) const { + return std::make_unique(*this, sequence_lengths_unk, params); } DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params) @@ -25,22 +25,26 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra extra_inputs_.Add(); } -RoamingArray DecoderOnly_State::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - if (!first_run_) { - UpdateInputsOutputs(next_tokens, next_indices, current_length); - } +RoamingArray DecoderOnly_State::Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) { + UpdateInputsOutputs(next_tokens, next_indices, total_length); int batch_size = static_cast(input_ids_.GetShape()[0]); State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); - + return logits_.Get(); } -void DecoderOnly_State::UpdateInputsOutputs(const RoamingArray& next_tokens_unk, RoamingArray beam_indices, int current_length) { - input_ids_.Update(next_tokens_unk); - position_inputs_.Update(current_length); - kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); +void DecoderOnly_State::RewindTo(size_t index) { + position_inputs_.RewindTo(index); + kv_cache_.RewindTo(index); +} + 
+void DecoderOnly_State::UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int total_length) { + input_ids_.Update(next_tokens); + size_t new_length = input_ids_.GetShape()[1]; + position_inputs_.Update(next_tokens, total_length, new_length); + kv_cache_.Update(beam_indices.GetCPU(), total_length); + logits_.Update(next_tokens, new_length); } } // namespace Generators diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index c51a71b91..c68d2473b 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -11,18 +11,20 @@ namespace Generators { struct DecoderOnly_Model : Model { DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env); - std::unique_ptr CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const override; + std::unique_ptr CreateState(RoamingArray sequence_lengths_unk, const GeneratorParams& params) const override; std::unique_ptr session_decoder_; }; struct DecoderOnly_State : State { - DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths, const GeneratorParams& params); - RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; + DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params); + RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) override; const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); }; - private: - void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, int current_length); + void RewindTo(size_t index) override; + + protected: + void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray next_indices, int current_length); const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 1b2e92093..28fe5d2c3 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -79,7 +79,7 @@ bool IntermediatePipelineState::SupportsPrimaryDevice() const { return false; } -RoamingArray IntermediatePipelineState::Run(int current_length, RoamingArray next_tokens, +RoamingArray IntermediatePipelineState::Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) { State::Run(*model_.sessions_[id_], *model_.run_options_, params_->BatchBeamSize()); @@ -106,10 +106,10 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode } } -RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArray next_tokens, +RoamingArray DecoderOnlyPipelineState::Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) { if (!first_run_) { - UpdateInputsOutputs(next_tokens, next_indices, current_length); + UpdateInputsOutputs(next_tokens, next_indices, total_length); } for (auto& pipeline_state : pipeline_states_) { @@ -202,7 +202,7 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr } // Run the intermediate pipeline state - pipeline_state->Run(current_length, next_tokens, next_indices); + pipeline_state->Run(total_length, next_tokens, next_indices); // Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store. 
// All non managed outputs are assumed to be on CPU @@ -239,12 +239,13 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr return logits_.Get(); } -void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray& next_tokens_unk, - RoamingArray beam_indices, int current_length) { - input_ids_.Update(next_tokens_unk); - position_inputs_.Update(current_length); - if (kv_cache_) kv_cache_->Update(beam_indices.GetCPU(), current_length); - logits_.Update(); +void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray& next_tokens, + RoamingArray beam_indices, int total_length) { + input_ids_.Update(next_tokens); + size_t new_length = input_ids_.GetShape()[1]; + position_inputs_.Update(next_tokens, total_length, new_length); + if (kv_cache_) kv_cache_->Update(beam_indices.GetCPU(), total_length); + logits_.Update(next_tokens, new_length); } OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) { diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index c47022655..8785ef4d6 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -53,14 +53,14 @@ struct DecoderOnlyPipelineState : State { DecoderOnlyPipelineState(const DecoderOnlyPipelineState&) = delete; DecoderOnlyPipelineState& operator=(const DecoderOnlyPipelineState&) = delete; - RoamingArray Run(int current_length, RoamingArray next_tokens, + RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) override; OrtValue* GetOutput(const char* name) override; private: void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, - int current_length); + int total_length); const DecoderOnlyPipelineModel& model_; std::vector> pipeline_states_; diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index baaaed345..13ea0f7a9 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -9,8 +9,8 @@ namespace Generators { Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& name) : state_{state}, - shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, - state_.params_->sequence_length, model_.config_->model.decoder.hidden_size}, + shape_{static_cast(state_.params_->search.batch_size) * state_.params_->search.num_beams, + 0, model_.config_->model.decoder.hidden_size}, type_{mode == Embeddings::Mode::Input ? 
model_.session_info_->GetInputDataType(name) : model_.session_info_->GetOutputDataType(name)}, diff --git a/src/models/gpt.cpp b/src/models/gpt.cpp index 275708038..1c2211ca9 100644 --- a/src/models/gpt.cpp +++ b/src/models/gpt.cpp @@ -35,11 +35,12 @@ RoamingArray Gpt_State::Run(int current_length, RoamingArray nex return logits_.Get(); } -void Gpt_State::UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray beam_indices, int current_length) { - input_ids_.Update(next_tokens); - position_inputs_.Update(current_length); - kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); +void Gpt_State::UpdateInputsOutputs(RoamingArray& next_tokens_unk, RoamingArray beam_indices, int total_length) { + input_ids_.Update(next_tokens_unk); + size_t new_length = input_ids_.GetShape()[1]; + position_inputs_.Update(next_tokens_unk, total_length, new_length); + kv_cache_.Update(beam_indices.GetCPU(), total_length); + logits_.Update(next_tokens_unk, new_length); } } // namespace Generators diff --git a/src/models/gpt.h b/src/models/gpt.h index d58351687..a0658d5f6 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -21,7 +21,7 @@ struct Gpt_State : State { RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; private: - void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray beam_indices, int current_length); + void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int current_length); const Gpt_Model& model_; diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 9daa0a628..d4b464747 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -8,25 +8,9 @@ namespace Generators { InputIDs::InputIDs(State& state) : state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); - shape_ = {state_.params_->batch_size, state_.params_->sequence_length}; + shape_ = {state_.params_->BatchBeamSize(), 0}; type_ = model_.session_info_->GetInputDataType(name_); - // If 64-bit, convert from 32-bit to 64-bit - if (type_ == Ort::TypeToTensorType) { - value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_); - auto* p_data = value_->GetTensorMutableData(); - for (auto v : state_.params_->input_ids) { - *p_data++ = v; - } - } else { - if (type_ != Ort::TypeToTensorType) - throw std::runtime_error("InputIDs must be int64 or int32"); - value_ = OrtValue::CreateTensor(model_.allocator_cpu_.GetInfo(), std::span(const_cast(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_); - } - - value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams); - shape_[0] *= state_.params_->search.num_beams; - if (state_.GetCapturedGraphInfo()) { sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get(); @@ -37,24 +21,11 @@ InputIDs::InputIDs(State& state) #endif } - const auto get_unpadded_sequence_length = [](std::span input_ids, - int32_t pad_token_id) { - int32_t seq_length = 0; - for (int32_t i = 0; i < input_ids.size(); i++) { - if (input_ids[i] == pad_token_id) { - break; - } - seq_length++; - } - return seq_length; - }; - if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) && model_.session_info_->HasInput(model_.config_->model.decoder.inputs.past_sequence_length)) { if (state_.params_->BatchBeamSize() != 1) { throw std::runtime_error("Batch size must be 1 for current_sequence_length and past_sequence_length inputs"); } - const int32_t current_sequence_length = 
get_unpadded_sequence_length(state_.params_->input_ids, model_.config_->model.pad_token_id); const std::array current_sequence_length_shape{1}; const std::array past_sequence_length_shape{1, 1}; @@ -63,10 +34,10 @@ InputIDs::InputIDs(State& state) throw std::runtime_error("current_sequence_length and past_sequence_length must be int32"); current_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, current_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.current_sequence_length)); - *current_sequence_length_->GetTensorMutableData() = current_sequence_length; + *current_sequence_length_->GetTensorMutableData() = 0; past_sequence_length_ = OrtValue::CreateTensor(*model_.allocator_device_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length)); - *past_sequence_length_->GetTensorMutableData() = current_sequence_length - 1; + *past_sequence_length_->GetTensorMutableData() = -1; } } @@ -84,10 +55,32 @@ void InputIDs::Add() { } } -void InputIDs::Update(RoamingArray next_tokens_unk) { - // Resize input_ids shape once if it doesn't match the decoder shape - if (shape_[1] != 1) { - shape_[1] = 1; +void InputIDs::Update(RoamingArray new_tokens) { + const auto get_unpadded_sequence_length = [](std::span input_ids, + int32_t pad_token_id) { + int32_t seq_length = 0; + for (int32_t i = 0; i < input_ids.size(); i++) { + if (input_ids[i] == pad_token_id) { + break; + } + seq_length++; + } + return seq_length; + }; + + if (current_sequence_length_ && past_sequence_length_) { + if (state_.params_->BatchBeamSize() != 1) { + throw std::runtime_error("Batch size must be 1 for current_sequence_length and past_sequence_length inputs"); + } + auto new_sequence_length = get_unpadded_sequence_length(new_tokens.GetCPU(), model_.config_->model.pad_token_id); + *current_sequence_length_->GetTensorMutableData() += new_sequence_length; + *past_sequence_length_->GetTensorMutableData() += new_sequence_length; + } + + // Resize input_ids shape based on new_tokens + size_t sequence_length = static_cast(new_tokens.GetCPU().size()) / shape_[0]; + if (shape_[1] != sequence_length) { + shape_[1] = sequence_length; if (!sb_input_ids_) { value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); @@ -115,7 +108,7 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { case DeviceType::CUDA: { #if USE_CUDA auto* data = value_->GetTensorMutableData(); - auto next_tokens = next_tokens_unk.GetGPU(); + auto next_tokens = new_tokens.GetGPU(); cuda::LaunchInt32ToInt64(next_tokens.data(), data, static_cast(next_tokens.size()), model_.cuda_stream_); #endif } break; @@ -126,8 +119,8 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value_int32_->GetTensorMutableRawData(), &source_resource)); auto source = std::span( - reinterpret_cast(next_tokens_unk.GetCPU().data()), - next_tokens_unk.GetCPU().size_bytes()); + reinterpret_cast(new_tokens.GetCPU().data()), + new_tokens.GetCPU().size_bytes()); model_.GetDmlUploadHeap()->BeginUploadToGpu( source_resource.Get(), @@ -147,9 +140,11 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { } break; case DeviceType::CPU: { auto* data = value_->GetTensorMutableData(); - auto next_tokens = next_tokens_unk.GetCPU(); - for (int i = 0; i < shape_[0]; i++) { - data[i] = next_tokens[i]; + auto next_tokens = new_tokens.GetCPU(); + for (int b = 0; b < 
shape_[0]; b++) { + for (int i = 0; i < shape_[1]; i++) { + data[b * shape_[1] + i] = next_tokens[b * shape_[1] + i]; + } } } } @@ -157,15 +152,10 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { auto* data = value_->GetTensorMutableData(); #if USE_CUDA if (model_.device_type_ == DeviceType::CUDA) - cudaMemcpyAsync(data, next_tokens_unk.GetGPU().data(), shape_[0] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_); + cudaMemcpyAsync(data, new_tokens.GetGPU().data(), shape_[0] * shape_[1] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_); else #endif - memcpy(data, next_tokens_unk.GetCPU().data(), shape_[0] * sizeof(int32_t)); - } - - if (current_sequence_length_ && past_sequence_length_) { - *current_sequence_length_->GetTensorMutableData() += 1; - *past_sequence_length_->GetTensorMutableData() += 1; + memcpy(data, new_tokens.GetCPU().data(), shape_[0] * shape_[1] * sizeof(int32_t)); } } diff --git a/src/models/input_ids.h b/src/models/input_ids.h index 5d61fb7bd..99e3f1116 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -9,7 +9,11 @@ struct InputIDs { InputIDs(const InputIDs&) = delete; InputIDs& operator=(const InputIDs&) = delete; + // Register input_ids as ORT session input. + // Called only once during initialization of state. void Add(); + // Resize input_ids based on size of next_tokens. + // Update value with next_tokens. void Update(RoamingArray next_tokens); auto& GetShape() const { return shape_; } diff --git a/src/models/kernels.cu b/src/models/kernels.cu index c93e20497..647add9dc 100644 --- a/src/models/kernels.cu +++ b/src/models/kernels.cu @@ -25,6 +25,24 @@ void Launch_UpdatePositionIds(T* positions, int batch_beam_size, cudaStream_t st template void Launch_UpdatePositionIds(int32_t* positions, int batch_beam_size, cudaStream_t stream); template void Launch_UpdatePositionIds(int64_t* positions, int batch_beam_size, cudaStream_t stream); +template +__global__ void UpdatePositionIds(T* positions, int total_length, int new_kv_length) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < new_kv_length) { + positions[i] = total_length + i; + } +} + +template +void Launch_UpdatePositionIds(T* positions, int total_length, int new_kv_length, cudaStream_t stream) { + int threads = std::min(256, new_kv_length); + int blocks = (new_kv_length + threads - 1) / threads; + UpdatePositionIds<<>>(positions, total_length, new_kv_length); +} + +template void Launch_UpdatePositionIds(int32_t* positions, int total_length, int new_kv_length, cudaStream_t stream); +template void Launch_UpdatePositionIds(int64_t* positions, int total_length, int new_kv_length, cudaStream_t stream); + template __global__ void CopyAndUpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length, int max_length) { @@ -65,6 +83,39 @@ template void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_ template void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_mask_data, int batch_beam_size, int current_length, int max_length, bool update_only, cudaStream_t stream); +template +__global__ void UpdateAttentionMaskStatic(T* mask_data, int new_kv_length, int total_length) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int past_length = total_length - new_kv_length; + if (i < new_kv_length) { + mask_data[past_length + i] = 1; + } +} + +template +__global__ void UpdateAttentionMask(T* mask_data, int total_length) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < total_length) 
{ + mask_data[i] = 1; + } +} + +template +void Launch_UpdateAttentionMask(T* mask_data, int new_kv_length , int total_length, bool update_static, cudaStream_t stream) { + if (update_static) { + int threads = std::min(256, new_kv_length); + int blocks = (new_kv_length + threads - 1) / threads; + UpdateAttentionMaskStatic<<>>(mask_data, new_kv_length, total_length); + } else { + int threads = std::min(256, total_length); + int blocks = (total_length + threads - 1) / threads; + UpdateAttentionMask<<>>(mask_data, total_length); + } +} + +template void Launch_UpdateAttentionMask(int32_t* mask_data, int new_kv_length , int total_length, bool update_static, cudaStream_t stream); +template void Launch_UpdateAttentionMask(int64_t* mask_data, int new_kv_length , int total_length, bool update_static, cudaStream_t stream); + __global__ void HandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= batch_beam_size) diff --git a/src/models/kernels.h b/src/models/kernels.h index 442172696..ee79eac62 100644 --- a/src/models/kernels.h +++ b/src/models/kernels.h @@ -8,8 +8,12 @@ namespace cuda { template void Launch_UpdatePositionIds(T* positions, int batch_beam_size, cudaStream_t stream); template +void Launch_UpdatePositionIds(T* positions, int total_length, int new_kv_length, cudaStream_t stream); +template void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length, int max_length, bool update_only, cudaStream_t stream); +template +void Launch_UpdateAttentionMask(T* mask_data, int new_kv_length , int total_length, bool update_static, cudaStream_t stream); void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream); diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 35f47ef26..c125eafe7 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -35,7 +35,7 @@ KV_Cache_Combined::KV_Cache_Combined(State& state) type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - shape_[3] = state_.params_->sequence_length; + shape_[3] = 0; for (int i = 0; i < layer_count_; ++i) { presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)); @@ -152,23 +152,21 @@ KV_Cache::KV_Cache(State& state) empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); // Set the size after empty_past_ has been created with 0 for this field - if (past_present_share_buffer_) + if (past_present_share_buffer_) { shape_[2] = state_.params_->search.max_length; - else - shape_[2] = state_.params_->sequence_length; - if (state_.GetCapturedGraphInfo()) { - assert(past_present_share_buffer_); - sb_kv_caches_.reserve(layer_count_ * 2); - for (int i = 0; i < layer_count_ * 2; ++i) { - sb_kv_caches_.push_back(state_.GetCapturedGraphInfo()->sb_kv_caches_[i].get()); + if (state_.GetCapturedGraphInfo()) { + sb_kv_caches_.reserve(layer_count_ * 2); + for (int i = 0; i < layer_count_ * 2; ++i) { + sb_kv_caches_.push_back(state_.GetCapturedGraphInfo()->sb_kv_caches_[i].get()); + } } - } - for (int i = 0; i < layer_count_ * 2; ++i) { - presents_.push_back( - sb_kv_caches_.empty() ? 
OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) - : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); + for (int i = 0; i < layer_count_ * 2; ++i) { + presents_.push_back( + sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) + : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); + } } } @@ -200,25 +198,79 @@ void KV_Cache::Add() { } } -void KV_Cache::Update(std::span beam_indices, int current_length) { +// TODO(aciddelgado): consider 0-initializing pasts somewhere +void KV_Cache::Update(std::span beam_indices, int total_length) { // If we're sharing past & present buffers there is nothing to do here, so early exit if (past_present_share_buffer_) return; - for (int i = 0; i < layer_count_ * 2; i++) { - if (beam_indices.empty()) { - pasts_[i] = std::move(presents_[i]); - } else { - PickPastState(beam_indices, i); + if (!is_first_update_) { + for (int i = 0; i < layer_count_ * 2; i++) { + if (beam_indices.empty()) { + pasts_[i] = std::move(presents_[i]); + } else { + PickPastState(beam_indices, i); + } + state_.inputs_[input_index_ + i] = pasts_[i].get(); } - state_.inputs_[input_index_ + i] = pasts_[i].get(); } - shape_[2] = current_length; + shape_[2] = total_length; for (int i = 0; i < layer_count_ * 2; i++) { presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); state_.outputs_[output_index_ + i] = presents_[i].get(); } + + is_first_update_ = false; +} + +// TODO(aciddelgado): test with past_present_share_buffer_ = false +void KV_Cache::RewindTo(size_t index) { + if (past_present_share_buffer_) { + return; + } else if (shape_[2] <= static_cast(index)) { + throw std::runtime_error("Requested length of rewind is greater than the current length."); + } + + is_first_update_ = true; + if (index == 0) { + for (int i = 0; i < layer_count_ * 2; i++) { + pasts_[i] = nullptr; + } + } else if (type_ == Ort::TypeToTensorType) { + RewindPastTensorsTo(index); + } else { + RewindPastTensorsTo(index); + } +} + +template +void KV_Cache::RewindPastTensorsTo(size_t index) { + assert(index > 0 && shape_[2] >= index && !past_present_share_buffer_); + std::array new_shape = shape_; + new_shape[2] = static_cast(index); + auto batch_x_num_heads = new_shape[0] * new_shape[1]; + auto new_length_x_head_size = new_shape[2] * new_shape[3]; + auto old_length_x_head_size = shape_[2] * new_shape[3]; + + for (int i = 0; i < layer_count_ * 2; i++) { + OrtValue& present = *presents_[i]; + std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + for (int j = 0; j < batch_x_num_heads; j++) { + auto present_data = present.GetTensorData() + j * old_length_x_head_size; + auto past_data = past->GetTensorMutableData() + j * new_length_x_head_size; +#if USE_CUDA + if (model_.device_type_ == DeviceType::CUDA) { + cudaMemcpyAsync(past_data, present_data, new_length_x_head_size * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_); + } else +#endif + { + copy(std::span(present_data, new_length_x_head_size), std::span(past_data, new_length_x_head_size)); + } + } + pasts_[i] = std::move(past); + state_.inputs_[input_index_ + i] = pasts_[i].get(); + } } // Copy present state to past state reordered by the beam_indices diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 7ffe104d4..957c7b205 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -34,19 +34,28 @@ struct KV_Cache { static bool IsCacheNeeded(const Model& model); void AddEncoder(); // If model has an initial 
encoder step, this is used + // Register input_ids as ORT session input. + // Called only once during initialization of state. void Add(); - void Update(std::span beam_indices, int current_length); + // Move present to past. Prepare present output for next generation iteration. + void Update(std::span beam_indices, int total_length); + void RewindTo(size_t index); template void PickPastState(std::span beam_indices, int index); void PickPastState(std::span beam_indices, int index); private: + template + void RewindPastTensorsTo(size_t index); + State& state_; const Model& model_{state_.model_}; int layer_count_; size_t input_index_{~0U}, output_index_{~0U}; bool past_present_share_buffer_; // True if model.decoder.past_present_share_buffer is set to true, and we're using cuda, and not beam search + bool is_first_update_{true}; + std::array shape_; ONNXTensorElementDataType type_; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 0e333e150..da1d9def2 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -11,7 +11,7 @@ namespace Generators { Logits::Logits(State& state) : state_{state}, - shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, model_.config_->model.vocab_size}, + shape_{static_cast(state_.params_->BatchBeamSize()), 0, model_.config_->model.vocab_size}, type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); @@ -31,6 +31,8 @@ Logits::Logits(State& state) cudaMemcpyAsync(cuda_eos_token_ids_.data(), cpu_ids.data(), cpu_ids.size() * sizeof(int32_t), ::cudaMemcpyHostToDevice, model_.cuda_stream_); } #endif + + input_sequence_lengths.resize(state_.params_->search.batch_size); } #pragma warning(push) @@ -39,19 +41,17 @@ Logits::Logits(State& state) RoamingArray Logits::Get() { size_t element_count = shape_[0] * shape_[1] * shape_[2]; - // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor // The model's output logits are {batch_size*num_beams, input_seq_len, vocab_size} OrtValue* logits_of_last_token = output_raw_.get(); + std::array shape_last{shape_[0], 1, shape_[2]}; if (shape_[1] != 1) { const size_t seq_length = shape_[1]; const size_t vocab_size = shape_[2]; const size_t num_beams = state_.params_->search.num_beams; const size_t element_count_last_token = shape_[0] * shape_[2]; - shape_[1] = 1; - // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it - output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_last, type_); #if USE_DML if (type_ == Ort::TypeToTensorType) { @@ -64,15 +64,9 @@ RoamingArray Logits::Get() { size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2; size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process - const auto* input_ids = state_.params_->input_ids.data(); - for (int batch_index = 0; batch_index < state_.params_->batch_size; batch_index++) { + for (int batch_index = 0; batch_index < state_.params_->search.batch_size; batch_index++) { // Find the first non pad token from the end - size_t token_index = seq_length; - while (token_index-- > 0) { - if (input_ids[token_index] != model_.config_->model.pad_token_id) - break; - } - + size_t token_index = input_sequence_lengths[batch_index] - 1; for (int beam_index = 0; beam_index < num_beams; beam_index++) { switch (model_.device_type_) { case DeviceType::DML: { @@ -117,8 +111,6 @@ RoamingArray Logits::Get() { vocab_index += vocab_size; } - - input_ids += seq_length; } element_count = shape_[0] * shape_[2]; // shape_[1] is now 1, so the element count must be updated @@ -151,7 +143,7 @@ RoamingArray Logits::Get() { #if USE_DML // DML doesn't support on-device scoring yet, so we need to download some data to the CPU if (model_.device_type_ == DeviceType::DML) { - value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_); + value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_last); } #endif @@ -199,11 +191,24 @@ RoamingArray Logits::Get() { #pragma warning(pop) -void Logits::Update() { - if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == 1) { +void Logits::Update(const RoamingArray& next_tokens, int new_kv_length) { + if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == new_kv_length) { return; } + // Store length of input sequence for each batch for the get step + for (int b = 0; b < state_.params_->search.batch_size; b++) { + // Find the first non pad token from the end + size_t token_index = new_kv_length; + while (token_index-- > 0) { + auto next_token = const_cast&>(next_tokens).GetCPU()[b * new_kv_length + token_index]; + if (next_token != model_.config_->model.pad_token_id) + break; + } + input_sequence_lengths[b] = token_index + 1; + } + + shape_[1] = new_kv_length; StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) : sb_logits->CreateTensorOnStaticBuffer(shape_, type_); diff --git a/src/models/logits.h b/src/models/logits.h index 49b3a827f..bb3ee02c1 100644 --- a/src/models/logits.h +++ b/src/models/logits.h @@ -9,10 +9,13 @@ namespace Generators { struct Logits { Logits(State& state); + // Register input_ids as ORT session input. void Add(); + // For first iteration, find last token of each beam and store it in output_last_tokens_. RoamingArray Get(); - void Update(); + // Resize logits to [bz, token_count, vocab_size] if necessary. + void Update(const RoamingArray& next_tokens, int new_kv_length); private: void HandleEOSArray(cpu_span logits); @@ -31,6 +34,8 @@ struct Logits { std::unique_ptr output_raw_; // Raw logits output from model + std::vector input_sequence_lengths; + // Used for decoding runs with cuda graphs. 
StaticBuffer* sb_logits32_{}; StaticBuffer* sb_logits16_{}; diff --git a/src/models/model.h b/src/models/model.h index b7b7bdfb2..65698e9ac 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -29,12 +29,14 @@ struct State { State(const GeneratorParams& params, const Model& model_); virtual ~State() = default; - virtual RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices = {}) = 0; + virtual RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices = {}) = 0; virtual const CapturedGraphInfo* GetCapturedGraphInfo() const { return nullptr; } virtual void Finalize() {} OrtValue* GetInput(const char* name); + virtual void RewindTo(size_t index) { (void)index; }; + virtual OrtValue* GetOutput(const char* name); void ClearIO(); // Clear all inputs/outputs diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index 7bb95634a..5ff1fc3d3 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -4,6 +4,8 @@ #include "../generators.h" #include "multi_modal_vision_model.h" +// TODO(aciddelgado): update to use new input logic + namespace Generators { namespace { @@ -134,10 +136,12 @@ RoamingArray DecoderState::Run(int current_length, RoamingArray return logits_.Get(); } -void DecoderState::UpdateInputsOutputs(int current_length, RoamingArray beam_indices) { - position_inputs_.Update(current_length); - kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); +void DecoderState::UpdateInputsOutputs(RoamingArray next_tokens, int total_length, RoamingArray beam_indices) { + int batch_size = static_cast(inputs_embeds_.GetShape()[0]); + size_t new_length = next_tokens.GetCPU().size() / batch_size; + position_inputs_.Update(next_tokens, total_length, new_length); + kv_cache_.Update(beam_indices.GetCPU(), total_length); + logits_.Update(next_tokens, new_length); } MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& model, @@ -179,7 +183,7 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra } embedding_state_->UpdateInputsAndOutputs(next_tokens); - decoder_state_->UpdateInputsOutputs(current_length, next_indices); + decoder_state_->UpdateInputsOutputs(next_tokens, current_length, next_indices); embedding_state_->Run(current_length, next_tokens, next_indices); decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_); diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index fef1a9f36..f5393f2be 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -86,7 +86,7 @@ struct DecoderState : State { private: friend struct MultiModalPipelineState; - void UpdateInputsOutputs(int current_length, RoamingArray beam_indices); + void UpdateInputsOutputs(RoamingArray next_tokens, int current_length, RoamingArray beam_indices); const MultiModalVisionModel& model_; const CapturedGraphInfo* captured_graph_info_; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 2666afc17..e996be0ec 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -31,22 +31,13 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArray && type_ != Ort::TypeToTensorType) throw std::runtime_error("position_ids & attention_mask only support int32 or int64 types"); - std::array shape{state_.params_->batch_size, state_.params_->sequence_length}; 
// Only batch_size initially, as we haven't expanded over the beams yet - position_ids_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_); - position_ids_next_ = OrtValue::CreateTensor(model.allocator_cpu_, std::array{shape[0], 1}, type_); - attention_mask_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_); - - initial_sequence_lengths_.resize(state_.params_->BatchBeamSize()); + std::array shape{state_.params_->search.batch_size, 0}; // Only batch_size initially, as we haven't expanded over the beams yet if (type_ == Ort::TypeToTensorType) - InitializeTensors(shape, sequence_lengths_unk); + InitializeSequenceLengths(shape, sequence_lengths_unk); else - InitializeTensors(shape, sequence_lengths_unk); + InitializeSequenceLengths(shape, sequence_lengths_unk); - position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams); - position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams); - attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams); - shape[0] *= state_.params_->search.num_beams; position_ids_shape_ = shape; attention_mask_shape_ = shape; @@ -75,12 +66,58 @@ void PositionInputs::Add() { } } -void PositionInputs::Update(int current_length) { +void PositionInputs::Update(const RoamingArray& next_tokens, int total_length, int new_length) { if (has_posid_input_) { - UpdatePositionIDs(current_length); + // Initialize on first update + if (is_first_update_) { + position_ids_shape_[1] = new_length; + if (type_ == Ort::TypeToTensorType) + CreateAndInitializePositionIDs(next_tokens, position_ids_shape_); + else + CreateAndInitializePositionIDs(next_tokens, position_ids_shape_); + } else { + // Batch size > 1 case + if (position_ids_shape_[0] > 1) + UpdatePositionIDs(); + // Batch size = 1 case (continuous decoding) + else + UpdatePositionIDs(total_length, new_length); + } } if (has_mask_input_) { - UpdateAttentionMask(current_length); + // Initialize on first update + if (is_first_update_) { + attention_mask_shape_[1] = new_length; + if (type_ == Ort::TypeToTensorType) + CreateAndInitializeAttentionMask(next_tokens, attention_mask_shape_); + else + CreateAndInitializeAttentionMask(next_tokens, attention_mask_shape_); + } else { + // Batch size > 1 case + if (attention_mask_shape_[0] > 1) + UpdateAttentionMask(total_length); + // Batch size = 1 case + else + UpdateAttentionMask(total_length, new_length); + } + } + is_first_update_ = false; +} + +void PositionInputs::RewindTo(size_t index) { + // Reset the state of the position inputs + if (index == 0) { + is_first_update_ = true; + is_first_posid_update_ = true; + is_first_mask_update_ = true; + // Rewind the mask input to a previous state + } else if (has_mask_input_) { + if (attention_mask_shape_[0] == 1) +#if USE_CUDA + RewindMask(index); + else +#endif + throw std::runtime_error("PositionInputs::RewindTo - Unsupported batch size"); } } @@ -98,7 +135,7 @@ void PositionInputs::AddPositionIDs() { state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str()); } -void PositionInputs::UpdatePositionIDs(int current_length) { +void PositionInputs::UpdatePositionIDs() { // Reallocate position_ids for the 2nd and onward shape if (is_first_posid_update_) { position_ids_shape_[1] = 1; @@ -193,7 +230,50 @@ void PositionInputs::UpdatePositionIDs(int current_length) { } } -void PositionInputs::UpdateAttentionMask(int current_length) { +void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { + 
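  // (Editorial annotation, not part of the diff) This two-argument overload only services the
  // batch_size == 1 continuous-decoding path: it resizes position_ids_ to cover just the
  // new_kv_length tokens being appended and recomputes their values from total_length.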
// Support batch_size == 1 only with current length > 0 and new kv length > 1 + if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) + throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + // Doesn't support DML at the moment + if (model_.device_type_ == DeviceType::DML) + throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML not supported for continuous decoding."); + // Reallocate position_ids when new_kv_length changes + if (position_ids_shape_[1] != new_kv_length) { + position_ids_shape_[1] = new_kv_length; + if (!sb_position_ids_) { + position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, type_); + } else { +#if USE_CUDA + position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_); + assert(model_.device_type_ == DeviceType::CUDA); +#endif + } + state_.inputs_[posid_input_index_] = position_ids_.get(); + } + is_first_posid_update_ = false; + // Just incrementing existing position IDs + switch (model_.device_type_) { + case DeviceType::CPU: { + if (type_ == Ort::TypeToTensorType) + UpdatePositionIDsImpl(total_length, new_kv_length); + else + UpdatePositionIDsImpl(total_length, new_kv_length); + break; + } +#if USE_CUDA + case DeviceType::CUDA: + if (type_ == Ort::TypeToTensorType) + cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), total_length, new_kv_length, model_.cuda_stream_); + else + cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), total_length, new_kv_length, model_.cuda_stream_); + break; +#endif + default: + throw std::runtime_error("PositionIDs::Update - Unsupported device type"); + } +} + +void PositionInputs::UpdateAttentionMask(int total_length) { // Update attention mask if (sb_attention_mask_) { #if USE_CUDA @@ -218,15 +298,14 @@ void PositionInputs::UpdateAttentionMask(int current_length) { attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); #endif } else { - assert(attention_mask_shape_[1] == current_length - 1); // We should always be growing by 1 - attention_mask_shape_[1] = current_length; + assert(attention_mask_shape_[1] == total_length - 1); // We should always be growing by 1 + attention_mask_shape_[1] = total_length; #if USE_DML if (model_.device_type_ == DeviceType::DML) { attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); } #endif - attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); } @@ -246,7 +325,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) { static_cast(attention_mask_shape_[0]), static_cast(attention_mask_shape_[1]), type_, - current_length, + total_length, attention_mask_resource.Get(), attention_mask_next_resource.Get()); is_second_mask_update_ = true; @@ -273,22 +352,22 @@ void PositionInputs::UpdateAttentionMask(int current_length) { if (type_ == Ort::TypeToTensorType) UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), - current_length); + total_length); else UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), - current_length); + total_length); break; } #if USE_CUDA case DeviceType::CUDA: { - int max_seq_len = sb_attention_mask_ ? state_.params_->search.max_length : current_length; + int max_seq_len = sb_attention_mask_ ? 
state_.params_->search.max_length : total_length; bool update_only = sb_attention_mask_ && !is_first_mask_update_; if (type_ == Ort::TypeToTensorType) { cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), static_cast(attention_mask_shape_[0]), - current_length, + total_length, max_seq_len, update_only, model_.cuda_stream_); @@ -296,7 +375,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) { cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), static_cast(attention_mask_shape_[0]), - current_length, + total_length, max_seq_len, update_only, model_.cuda_stream_); @@ -317,38 +396,143 @@ void PositionInputs::UpdateAttentionMask(int current_length) { #endif state_.inputs_[mask_input_index_] = attention_mask_.get(); + is_first_mask_update_ = false; +} + +void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { + // Support batch_size == 1 only with current length > 0 and new kv length > 1 + if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) + throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); + // Doesn't support DML at the moment + if (model_.device_type_ == DeviceType::DML) + throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML not supported for continuous decoding."); + // Update attention mask + if (sb_attention_mask_ && is_first_mask_update_) { +#if USE_CUDA + int past_length = total_length - new_kv_length; + int max_length = state_.params_->search.max_length; + attention_mask_shape_[1] = max_length; + attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); + if (type_ == Ort::TypeToTensorType) { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int32_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int32_t) * (max_length - past_length), + model_.cuda_stream_); + } else { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int64_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int64_t) * (max_length - past_length), + model_.cuda_stream_); + } +#endif + } else if (!sb_attention_mask_) { + attention_mask_shape_[1] = total_length; + attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); + } + + switch (model_.device_type_) { + case DeviceType::CPU: { + if (type_ == Ort::TypeToTensorType) + UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData(), total_length); + else + UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData(), total_length); + break; + } +#if USE_CUDA + case DeviceType::CUDA: { + bool update_static = sb_attention_mask_; + if (type_ == Ort::TypeToTensorType) { + cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData(), + new_kv_length, + total_length, + update_static, + model_.cuda_stream_); + } else { + cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData(), + new_kv_length, + total_length, + update_static, + model_.cuda_stream_); + } + break; + } +#endif + default: + throw std::runtime_error("PositionInputs::Update - Unsupported device type"); + } + state_.inputs_[mask_input_index_] = attention_mask_.get(); is_first_mask_update_ = false; } template -void 
PositionInputs::InitializeTensors(std::array shape, cpu_span sequence_lengths) { +void PositionInputs::CreateAndInitializePositionIDs(const RoamingArray& next_tokens, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens - auto* mask_data = attention_mask_->GetTensorMutableData(); + position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); + position_ids_next_ = OrtValue::CreateTensor(model_.allocator_cpu_, std::array{shape[0], 1}, type_); auto* position_data = position_ids_->GetTensorMutableData(); auto* position_data_next = position_ids_next_->GetTensorMutableData(); - const auto* word_id = state_.params_->input_ids.data(); - auto* mask = mask_data; + const auto* word_id = const_cast&>(next_tokens).GetCPU().data(); auto* position = position_data; for (int i = 0; i < shape[0]; i++) { T abs_position = 0; - for (int j = 0; j < shape[1]; j++, word_id++, mask++, position++) { + for (int j = 0; j < shape[1]; j++, word_id++, position++) { if (*word_id == model_.config_->model.pad_token_id) { - *mask = 0; *position = 0; } else { - *mask = 1; *position = abs_position++; } } position_data_next[i] = abs_position; - for (int k = 0; k < state_.params_->search.num_beams; k++) { - sequence_lengths[i * state_.params_->search.num_beams + k] = static_cast(abs_position); - initial_sequence_lengths_[i * state_.params_->search.num_beams + k] = static_cast(abs_position); + } + + // Move tensors to appropriate device and expand by num_beams + position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams); + position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams); + position_ids_shape_[0] *= state_.params_->search.num_beams; + state_.inputs_[posid_input_index_] = position_ids_.get(); +} + +template +void PositionInputs::CreateAndInitializeAttentionMask(const RoamingArray& next_tokens, std::array shape) { + // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. 
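  // Worked example (illustrative, not part of this change): with pad_token_id == 0 and an
  // input row {0, 0, 42, 43, 44}, CreateAndInitializeAttentionMask below fills the mask row
  // with {0, 0, 1, 1, 1}, while CreateAndInitializePositionIDs above fills the position row
  // with {0, 0, 0, 1, 2} and sets position_ids_next_[row] = 3 (the count of non-pad tokens).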
+ // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens + attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); + auto* mask_data = attention_mask_->GetTensorMutableData(); + const auto* word_id = const_cast&>(next_tokens).GetCPU().data(); + auto* mask = mask_data; + for (int i = 0; i < shape[0]; i++) { + T abs_position = 0; + for (int j = 0; j < shape[1]; j++, word_id++, mask++) { + if (*word_id == model_.config_->model.pad_token_id) { + *mask = 0; + } else { + *mask = 1; + } } } + + // Move tensors to appropriate device and expand by num_beams + attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams); + attention_mask_shape_[0] *= state_.params_->search.num_beams; + state_.inputs_[mask_input_index_] = attention_mask_.get(); +} + +template +void PositionInputs::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { + for (int i = 0; i < shape[0] * state_.params_->search.num_beams; i++) { + sequence_lengths_unk[i] = 0; + } } template @@ -360,6 +544,13 @@ void PositionInputs::UpdatePositionIDsImpl() { } }; +template +void PositionInputs::UpdatePositionIDsImpl(int current_length, int new_kv_length) { + auto* data = position_ids_->GetTensorMutableData(); + for (int i = 0; i < new_kv_length; i++) + data[i] = i + current_length + new_kv_length; +}; + template void PositionInputs::UpdateAttentionMaskImpl(T* data, const T* old_data, int current_length) { for (int i = 0; i < attention_mask_shape_[0]; i++) { @@ -370,4 +561,39 @@ void PositionInputs::UpdateAttentionMaskImpl(T* data, const T* old_data, int cur } }; +template +void PositionInputs::UpdateAttentionMaskImpl(T* data, int total_length) { + for (int i = 0; i < total_length; i++) { + data[i] = 1; + } +}; + +#if USE_CUDA +void PositionInputs::RewindMask(size_t index) { + if (sb_attention_mask_ && !is_first_mask_update_) { + int past_length = static_cast(index); + int max_length = static_cast(state_.params_->search.max_length); + if (type_ == Ort::TypeToTensorType) { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int32_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int32_t) * (max_length - past_length), + model_.cuda_stream_); + } else { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int64_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int64_t) * (max_length - past_length), + model_.cuda_stream_); + } + } +} +#endif + } // namespace Generators diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 259f5c0c2..438e2798d 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -11,25 +11,45 @@ namespace Generators { struct PositionInputs { PositionInputs(const Model& model, State& state, RoamingArray& sequence_lengths); + PositionInputs(const Model& model, State& state); void Add(); - void Update(int current_length); + void Update(const RoamingArray& next_tokens_unk, int total_length, int new_length); + + void RewindTo(size_t index); private: void AddAttentionMask(); void AddPositionIDs(); - void UpdatePositionIDs(int current_length); - void UpdateAttentionMask(int current_length); + // Batch size > 1 case + void UpdatePositionIDs(); + void UpdateAttentionMask(int total_length); + // Batch size == 1 case. 
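  // (Editorial annotation, not part of the diff) The two-argument overloads declared below
  // handle the batch_size == 1 continuous-decoding path: new_length tokens are appended on
  // top of an existing KV cache holding total_length - new_length tokens.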
+ void UpdatePositionIDs(int total_length, int new_length); + void UpdateAttentionMask(int total_length, int new_length); template - void InitializeTensors(std::array shape, cpu_span sequence_lengths); + void InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk); + template + void CreateAndInitializePositionIDs(const RoamingArray& next_tokens, std::array shape); + template + void CreateAndInitializeAttentionMask(const RoamingArray& next_tokens, std::array shape); template void UpdatePositionIDsImpl(); template void UpdateAttentionMaskImpl(T* data, const T* old_data, int current_length); + template + void UpdatePositionIDsImpl(int total_length, int new_kv_length); + template + void UpdateAttentionMaskImpl(T* data, int total_length); + +#if USE_CUDA + void RewindMask(size_t index); +#endif + const Model& model_; State& state_; @@ -41,14 +61,13 @@ struct PositionInputs { bool has_mask_input_{false}; bool has_posid_input_{false}; - std::array position_ids_shape_{}; // {params.batch_size*params.beam_size, params.sequence_length} + std::array position_ids_shape_{}; std::unique_ptr position_ids_; - std::array attention_mask_shape_{}; // {params.batch_size*params.beam_size, params.sequence_length} + std::array attention_mask_shape_{}; std::unique_ptr attention_mask_; std::unique_ptr position_ids_next_; // Replaces position_ids_ after the first Run() call std::unique_ptr attention_mask_next_; // Replaces attention_mask_ after the first Run() call - std::vector initial_sequence_lengths_; // Used for decoding runs with cuda graphs. StaticBuffer* sb_position_ids_{}; @@ -56,6 +75,7 @@ struct PositionInputs { bool is_first_posid_update_{true}; bool is_first_mask_update_{true}; + bool is_first_update_{true}; #if USE_DML std::optional dml_update_mask_kernel_; diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index 4f7f4bbe1..50a46587c 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -5,6 +5,8 @@ #include #include "kernels.h" +// TODO(aciddelgado): update whisper to new paradigm + namespace Generators { Whisper_Model::Whisper_Model(std::unique_ptr config, OrtEnv& ort_env) @@ -69,7 +71,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray s auto sequence_lengths = sequence_lengths_unk.GetCPU(); for (int i = 0; i < decoder_input_ids_.GetShape()[0]; i++) { - sequence_lengths[i] = static_cast(params_->sequence_length); + sequence_lengths[i] = 0; } input_names_.push_back("encoder_input_ids"); @@ -87,7 +89,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray s { auto layer_count = model_.config_->model.decoder.num_hidden_layers; - std::array shape{params_->BatchBeamSize(), model_.config_->model.decoder.num_attention_heads, params_->sequence_length, model_.config_->model.decoder.head_size}; + std::array shape{params_->BatchBeamSize(), model_.config_->model.decoder.num_attention_heads, 0, model_.config_->model.decoder.head_size}; auto type = model_.session_info_->GetOutputDataType(output_names_[kv_cache_indices]); for (int i = 0; i < layer_count * 2; i++) { @@ -270,13 +272,13 @@ RoamingArray Whisper_State::Run(int current_length, RoamingArray if (model_.session_info_->HasInput("cache_indirection")) { #if USE_CUDA - cache_indirection_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array{params_->batch_size, params_->search.num_beams, params_->search.max_length}); + cache_indirection_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array{params_->search.batch_size, params_->search.num_beams, 
params_->search.max_length}); cache_indirection_index_ = inputs_.size(); input_names_.push_back("cache_indirection"); inputs_.push_back(cache_indirection_.get()); auto data = gpu_span{cache_indirection_->GetTensorMutableData(), - static_cast(params_->batch_size) * params_->search.num_beams * params_->search.max_length}; + static_cast(params_->search.batch_size) * params_->search.num_beams * params_->search.max_length}; CudaCheck() == cudaMemsetAsync(data.data(), 0, data.size_bytes(), params_->cuda_stream); #endif } @@ -315,7 +317,8 @@ RoamingArray Whisper_State::Run(int current_length, RoamingArray void Whisper_State::UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray beam_indices, int current_length, bool search_buffers) { decoder_input_ids_.Update(next_tokens); kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); + size_t new_length = decoder_input_ids_.GetShape()[1]; + logits_.Update(next_tokens, new_length); if (past_sequence_length_) { auto data = past_sequence_length_->GetTensorMutableData(); @@ -335,8 +338,8 @@ void Whisper_State::UpdateInputsOutputs(const RoamingArray& next_tokens gpu_span beam_indices_gpu = beam_indices.GetGPU(); cuda_unique_ptr beam_indices_ptr; if (beam_indices_gpu.empty()) { - beam_indices_ptr = CudaMallocArray(params_->batch_size, &beam_indices_gpu); - std::vector beam_indices_cpu(params_->batch_size, 0); + beam_indices_ptr = CudaMallocArray(params_->search.batch_size, &beam_indices_gpu); + std::vector beam_indices_cpu(params_->search.batch_size, 0); std::iota(beam_indices_cpu.begin(), beam_indices_cpu.end(), 0); CudaCheck() == cudaMemcpyAsync(beam_indices_gpu.data(), beam_indices_cpu.data(), beam_indices_cpu.size() * sizeof(int32_t), @@ -344,15 +347,15 @@ void Whisper_State::UpdateInputsOutputs(const RoamingArray& next_tokens } std::unique_ptr new_cache_indirection; auto cache_indirection_type = model_.session_info_->GetInputDataType("cache_indirection"); - auto cache_indirection_shape = std::array{params_->batch_size, params_->search.num_beams, params_->search.max_length}; + auto cache_indirection_shape = std::array{params_->search.batch_size, params_->search.num_beams, params_->search.max_length}; new_cache_indirection = OrtValue::CreateTensor(*model_.allocator_device_, cache_indirection_shape, cache_indirection_type); cuda::UpdateCacheIndirectionKernelLauncher(new_cache_indirection->GetTensorMutableData(), cache_indirection_->GetTensorData(), beam_indices_gpu.data(), - params_->batch_size, + params_->search.batch_size, params_->search.num_beams, - params_->sequence_length, + 0, params_->search.max_length, current_length, model_.cuda_stream_); @@ -375,7 +378,7 @@ void Whisper_State::UpdateInputsOutputs(const RoamingArray& next_tokens cuda::LaunchCopyCrossQKSingleDecodeStep(model_.cuda_stream_, cross_qk_search_buffer_->GetTensorMutableData(), output_cross_qk_ptrs_gpu_.data(), - current_length - params_->sequence_length, + current_length - 0, params_->BatchBeamSize(), model_.config_->model.decoder.num_hidden_layers, static_cast(output_cross_qk_dims[1]), @@ -396,11 +399,11 @@ void Whisper_State::Finalize() { // Instantiate final output for cross QKs auto num_alignment_heads = alignment_heads_->GetTensorTypeAndShapeInfo()->GetShape()[0]; auto cross_qk_type = model_.session_info_->GetOutputDataType("output_cross_qk_0"); - auto cross_qk_shape = std::array{params_->batch_size, params_->search.num_return_sequences, num_alignment_heads, decoded_length, 1500}; + auto cross_qk_shape = std::array{params_->search.batch_size, 
params_->search.num_return_sequences, num_alignment_heads, decoded_length, 1500}; cross_qk_final_ = OrtValue::CreateTensor(*model_.allocator_device_, cross_qk_shape, cross_qk_type); cuda::LaunchFinalizeCrossQK(model_.cuda_stream_, - decoded_length - params_->sequence_length, + decoded_length - 0, decoded_length, static_cast(output_cross_qk_dims[0]), params_->search.num_beams, diff --git a/src/ort_genai.h b/src/ort_genai.h index 36b56d9bd..2c810f986 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -200,15 +200,7 @@ struct OgaGeneratorParams : OgaAbstract { void SetSearchOptionBool(const char* name, bool value) { OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value)); - } - - void SetInputIDs(const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { - OgaCheckResult(OgaGeneratorParamsSetInputIDs(this, input_ids, input_ids_count, sequence_length, batch_size)); - } - - void SetInputSequences(const OgaSequences& sequences) { - OgaCheckResult(OgaGeneratorParamsSetInputSequences(this, &sequences)); - } + } void SetModelInput(const char* name, OgaTensor& tensor) { OgaCheckResult(OgaGeneratorParamsSetModelInput(this, name, &tensor)); @@ -226,7 +218,7 @@ struct OgaGeneratorParams : OgaAbstract { }; struct OgaGenerator : OgaAbstract { - static std::unique_ptr Create(const OgaModel& model, const OgaGeneratorParams& params) { + static std::unique_ptr Create(const OgaModel& model, OgaGeneratorParams& params) { OgaGenerator* p; OgaCheckResult(OgaCreateGenerator(&model, ¶ms, &p)); return std::unique_ptr(p); @@ -236,14 +228,22 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } - void ComputeLogits() { - OgaCheckResult(OgaGenerator_ComputeLogits(this)); + void AddInputSequences(const OgaSequences& sequences) { + OgaCheckResult(OgaGenerator_AddInputSequences(this, &sequences)); + } + + void AddInputTokens(int32_t* input_ids, size_t input_ids_count) { + OgaCheckResult(OgaGenerator_AddInputTokens(this, input_ids, input_ids_count)); } void GenerateNextToken() { OgaCheckResult(OgaGenerator_GenerateNextToken(this)); } + void RewindToLength(size_t length) { + OgaCheckResult(OgaGenerator_RewindToLength(this, length)); + } + size_t GetSequenceCount(size_t index) const { return OgaGenerator_GetSequenceCount(this, index); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 4ce580fad..844a2f749 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -9,6 +9,7 @@ #include "generators.h" #include "models/model.h" #include "search.h" +#include "smartptrs.h" namespace Generators { @@ -174,36 +175,6 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGen OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* oga_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { - OGA_TRY - auto& params = *reinterpret_cast(oga_params); - params.input_ids = std::span(input_ids, input_ids_count); - params.sequence_length = static_cast(sequence_length); - params.batch_size = static_cast(batch_size); - if (params.sequence_length * params.batch_size != input_ids_count) - throw std::runtime_error("sequence length * batch size is not equal to input_ids_count"); - return nullptr; - OGA_CATCH -} - -OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* oga_params, const OgaSequences* p_sequences) { - OGA_TRY - auto& params = *reinterpret_cast(oga_params); - auto& sequences = *reinterpret_cast(p_sequences); - - 
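Taken together, the ort_genai.h changes above move input seeding off OgaGeneratorParams and onto the generator itself, and add a rewind entry point. A rough usage sketch of the reworked C++ wrapper (model construction, tokenization and error handling omitted; model, prompt_sequences and kept_length are hypothetical) could look like:

// Sketch only, not part of this diff: the generator now owns its input tokens.
auto params = OgaGeneratorParams::Create(*model);          // model: a previously created OgaModel
params->SetSearchOption("max_length", 256);

auto generator = OgaGenerator::Create(*model, *params);
generator->AddInputSequences(*prompt_sequences);           // prompt_sequences: tokenizer output
while (!generator->IsDone())
  generator->GenerateNextToken();

// Continuous decoding: keep the same generator/KV cache, optionally rewind away part of
// the generated tail, then append follow-up tokens and resume.
std::vector<int32_t> follow_up_ids = {/* tokenized follow-up turn */};
generator->RewindToLength(kept_length);                    // kept_length: hypothetical cut-off
generator->AddInputTokens(follow_up_ids.data(), follow_up_ids.size());
while (!generator->IsDone())
  generator->GenerateNextToken();

const auto* output_ids = generator->GetSequenceData(0);    // decode with the tokenizer as before
const size_t output_len = generator->GetSequenceCount(0);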
std::vector> span_sequences; - for (size_t i = 0; i < sequences.size(); i++) { - span_sequences.emplace_back(sequences[i]); - } - - params.input_ids_owner = Generators::PadInputs(span_sequences, params.config.model.pad_token_id); - params.batch_size = static_cast(sequences.size()); - params.sequence_length = static_cast(params.input_ids_owner.size() / params.batch_size); - params.input_ids = params.input_ids_owner; - return nullptr; - OGA_CATCH -} - OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* oga_params, const OgaNamedTensors* p_named_tensors) { OGA_TRY auto& params = *reinterpret_cast(oga_params); @@ -232,28 +203,38 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorPa OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) { +OgaResult* OgaCreateGenerator(const OgaModel* model, OgaGeneratorParams* generator_params, OgaGenerator** out) { OGA_TRY - auto result = Generators::Generate(*reinterpret_cast(model), *reinterpret_cast(generator_params)); - *out = reinterpret_cast(std::make_unique(std::move(result)).release()); + *out = reinterpret_cast(CreateGenerator(*reinterpret_cast(model), *reinterpret_cast(generator_params)).release()); return nullptr; OGA_CATCH } -OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) { +bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) { + return reinterpret_cast(generator)->IsDone(); +} + +OgaResult* OGA_API_CALL OgaGenerator_AddInputSequences(OgaGenerator* oga_generator, const OgaSequences* p_sequences) { OGA_TRY - *out = reinterpret_cast(CreateGenerator(*reinterpret_cast(model), *reinterpret_cast(generator_params)).release()); + auto& generator = *reinterpret_cast(oga_generator); + auto& params = *generator.state_->params_; + auto& sequences = *reinterpret_cast(p_sequences); + + std::vector> span_sequences; + for (size_t i = 0; i < sequences.size(); i++) { + span_sequences.emplace_back(sequences[i]); + } + + auto input_ids = Generators::PadInputs(span_sequences, generator.model_->config_->model.pad_token_id); + generator.AddTokens(input_ids); return nullptr; OGA_CATCH } -bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) { - return reinterpret_cast(generator)->IsDone(); -} - -OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) { +OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count) { OGA_TRY - reinterpret_cast(generator)->ComputeLogits(); + auto& generator = *reinterpret_cast(oga_generator); + generator.AddTokens(Generators::cpu_span(input_ids, input_ids_count)); return nullptr; OGA_CATCH } @@ -265,6 +246,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_RewindToLength(OgaGenerator* generator, size_t new_length) { + OGA_TRY + reinterpret_cast(generator)->RewindToLength(new_length); + return nullptr; + OGA_CATCH +} + OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) { OGA_TRY auto& generator = *reinterpret_cast(oga_generator); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index c1d03f8e1..bdba5b9e4 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -191,26 +191,6 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGenerato OGA_EXPORT OgaResult* 
OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* generator_params, const char* name, bool value); OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size); -/* - * \brief Sets the input ids for the generator params. The input ids are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] input_ids The input ids array of size input_ids_count = batch_size * sequence_length. - * \param[in] input_ids_count The total number of input ids. - * \param[in] sequence_length The sequence length of the input ids. - * \param[in] batch_size The batch size of the input ids. - * \return OgaResult containing the error message if the setting of the input ids failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, - size_t input_ids_count, size_t sequence_length, size_t batch_size); - -/* - * \brief Sets the input id sequences for the generator params. The input id sequences are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] sequences The input id sequences. - * \return OgaResult containing the error message if the setting of the input id sequences failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences); - OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* generator_params, const OgaNamedTensors* named_tensors); /* @@ -231,7 +211,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(Oga * \param[out] out The created generator. * \return OgaResult containing the error message if the generator creation failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, OgaGeneratorParams* params, OgaGenerator** out); /* * \brief Destroys the given generator. @@ -246,14 +226,39 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); */ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); +/* + * \brief Adds the input ids to the generator. The input ids are used to seed the generation. + * \param[in] oga_generator The generator to add the input ids to. + * \param[in] p_sequences The input id sequences. + * \return OgaResult containing the error message if the setting of the input ids failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputSequences(OgaGenerator* oga_generator, const OgaSequences* p_sequences); + +/* + * \brief Adds the input ids to the generator. The input ids are used to seed the generation. + * \param[in] oga_generator The generator to add the input ids to. + * \param[in] input_ids The input ids to add. + * \param[in] input_ids_count The number of input ids to add (batch_size * sequence_length). + * \return OgaResult containing the error message if the setting of the input ids failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count); + /* * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. 
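For the flat C API, the new OgaGenerator_* entry points declared above replace the removed OgaGeneratorParamsSetInputIDs / OgaGeneratorParamsSetInputSequences pair. A hedged sketch of seeding a generator from a raw token buffer (model and params creation are omitted, the token values are made up, and OgaCheckResult is the throwing helper from ort_genai.h; plain C callers would inspect the returned OgaResult* themselves):

// Sketch only, not part of this diff.
int32_t input_ids[] = {32010, 9906, 32007, 32001};  // batch_size * sequence_length ids (illustrative)
OgaGenerator* generator = nullptr;
OgaCheckResult(OgaCreateGenerator(model, params, &generator));
OgaCheckResult(OgaGenerator_AddInputTokens(generator, input_ids, sizeof(input_ids) / sizeof(input_ids[0])));
while (!OgaGenerator_IsDone(generator))
  OgaCheckResult(OgaGenerator_GenerateNextToken(generator));
OgaDestroyGenerator(generator);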
* \param[in] generator The generator to compute the logits for. * \return OgaResult containing the error message if the computation of the logits failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); +/* + * \brief Rewinds the generator to the given length. This is useful when the user wants to rewind the generator to a specific length + * and continue generating from that point. + * \param[in] generator The generator to rewind to the given length. + * \param[in] new_length The new length to rewind the generator to. + * \return OgaResult containing the error message if the rewinding failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_RewindToLength(OgaGenerator* generator, size_t new_length); + /* * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor * and will be released when the OgaTensor is destroyed diff --git a/src/python/python.cpp b/src/python/python.cpp index ffbbc5d36..4247c17b6 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -9,6 +9,7 @@ #include "../search.h" #include "../models/model.h" #include "../logging.h" +#include "../smartptrs.h" using namespace pybind11::literals; @@ -31,7 +32,7 @@ struct npy_format_descriptor { } // namespace pybind11 template -std::span ToSpan(pybind11::array_t v) { +Generators::cpu_span ToSpan(pybind11::array_t v) { if constexpr (std::is_const_v) return {v.data(), static_cast(v.size())}; else @@ -219,23 +220,7 @@ struct PyGeneratorParams { std::shared_ptr params_; - // Turn the python py_input_ids_ into the low level parameters void Prepare() { - // TODO: This will switch to using the variant vs being ifs - if (py_input_ids_.size() != 0) { - if (py_input_ids_.ndim() == 1) { // Just a 1D array - params_->batch_size = 1; - params_->sequence_length = static_cast(py_input_ids_.shape(0)); - } else { - if (py_input_ids_.ndim() != 2) - throw std::runtime_error("Input IDs can only be 1 or 2 dimensional"); - - params_->batch_size = static_cast(py_input_ids_.shape(0)); - params_->sequence_length = static_cast(py_input_ids_.shape(1)); - } - params_->input_ids = ToSpan(py_input_ids_); - } - if (py_whisper_input_features_.size() != 0) { GeneratorParams::Whisper& whisper = params_->inputs.emplace(); whisper.input_features = std::make_shared(ToOrtValue(py_whisper_input_features_)); @@ -277,7 +262,6 @@ struct PyGeneratorParams { params_->TryGraphCapture(max_batch_size.cast()); } - pybind11::array_t py_input_ids_; pybind11::array py_whisper_input_features_; pybind11::array py_alignment_heads_; @@ -293,8 +277,7 @@ struct PyNamedTensors { struct PyGenerator { PyGenerator(Model& model, PyGeneratorParams& params) { - params.Prepare(); - generator_ = CreateGenerator(model, params); + generator_ = CreateGenerator(model, *params.params_); } pybind11::array_t GetNextTokens() { @@ -307,18 +290,22 @@ struct PyGenerator { return ToPython(py_sequence_.GetCPU()); } - void ComputeLogits() { - generator_->ComputeLogits(); - } - pybind11::array GetOutput(const std::string& name) { return ToNumpy(generator_->state_->GetOutput(name.c_str()), *(generator_->model_)); } + void AddTokens(pybind11::array_t tokens) { + generator_->AddTokens(ToSpan(tokens)); + } + void GenerateNextToken() { generator_->GenerateNextToken(); } + void RewindToLength(size_t new_length) { + generator_->RewindToLength(new_length); + } + bool IsDone() const { return 
generator_->IsDone(); } @@ -379,7 +366,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->config.model.pad_token_id; }) .def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->config.model.eos_token_id; }) .def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->config.model.vocab_size; }) - .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_) // TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_) .def_readwrite("alignment_heads", &PyGeneratorParams::py_alignment_heads_) @@ -422,7 +408,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def(pybind11::init([](const std::string& config_path) { return CreateModel(GetOrtEnv(), config_path.c_str()); })) - .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) .def_property_readonly( "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on") .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); }); @@ -430,9 +415,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "Generator") .def(pybind11::init()) .def("is_done", &PyGenerator::IsDone) - .def("compute_logits", &PyGenerator::ComputeLogits) .def("get_output", &PyGenerator::GetOutput) + .def("add_input_tokens", &PyGenerator::AddTokens) .def("generate_next_token", &PyGenerator::GenerateNextToken) + .def("rewind_to_length", &PyGenerator::RewindToLength) .def("get_next_tokens", &PyGenerator::GetNextTokens) .def("get_sequence", &PyGenerator::GetSequence); diff --git a/src/search.cpp b/src/search.cpp index 6dc938dc4..84338bb72 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -9,7 +9,7 @@ namespace Generators { Search_Cpu::Search_Cpu(const GeneratorParams& params) : Search{params}, - sequences_{params.input_ids, params.batch_size, params.search.num_beams, params_->search.max_length} { + sequences_{params.search.batch_size, params.search.num_beams, params_->search.max_length} { auto batch_beam_size = params.BatchBeamSize(); sequence_lengths_buffer_ = AllocateArray(batch_beam_size, &sequence_lengths_); } @@ -26,10 +26,10 @@ GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params) gen_.seed(seq); } - next_tokens_buffer_ = AllocateArray(params.batch_size, &next_tokens_); + next_tokens_buffer_ = AllocateArray(params.search.batch_size, &next_tokens_); memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); - eos_seen_buffer_ = AllocateArray(params.batch_size, &eos_seen_); + eos_seen_buffer_ = AllocateArray(params.search.batch_size, &eos_seen_); memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); } @@ -75,7 +75,7 @@ void BeamSearch_Cpu::SelectTop() { // TODO(aciddelgado): use thread pool to parallel int offset = 0; int batch_beam_index = 0; - for (int i = 0; i < params_->batch_size; i++) { + for (int i = 0; i < params_->search.batch_size; i++) { for (int j = 0; j < params_->search.num_beams; j++, batch_beam_index++) { for (int k = 0; k < params_->config.model.vocab_size; k++, offset++) { next_token_scores_[offset] += beam_scores[batch_beam_index]; @@ -92,16 +92,16 @@ void BeamSearch_Cpu::SelectTop() { bool operator<(const ScoreIndex& s) const { return score < s.score; } }; - auto scores = std::make_unique(top_k * params_->batch_size); // 
Score of top_k tokens - auto indices = std::make_unique(top_k * params_->batch_size); // beam index of top_k tokens - auto tokens = std::make_unique(top_k * params_->batch_size); // token id of top_k tokens + auto scores = std::make_unique(top_k * params_->search.batch_size); // Score of top_k tokens + auto indices = std::make_unique(top_k * params_->search.batch_size); // beam index of top_k tokens + auto tokens = std::make_unique(top_k * params_->search.batch_size); // token id of top_k tokens - auto next_scores = std::span(scores.get(), top_k * params_->batch_size); - auto next_indices = std::span(indices.get(), top_k * params_->batch_size); - auto next_tokens = std::span(tokens.get(), top_k * params_->batch_size); + auto next_scores = std::span(scores.get(), top_k * params_->search.batch_size); + auto next_indices = std::span(indices.get(), top_k * params_->search.batch_size); + auto next_tokens = std::span(tokens.get(), top_k * params_->search.batch_size); // TODO(aciddelgado): Optimize this top k with partial sort - for (size_t batch_index = 0; batch_index < static_cast(params_->batch_size); batch_index++) { + for (size_t batch_index = 0; batch_index < static_cast(params_->search.batch_size); batch_index++) { std::priority_queue> queue; auto token_scores_sub = next_token_scores_.subspan(batch_index * params_->search.num_beams * params_->config.model.vocab_size, static_cast(params_->search.num_beams) * params_->config.model.vocab_size); for (int i = 0; i < token_scores_sub.size(); i++) { @@ -134,7 +134,7 @@ void BeamSearch_Cpu::SelectTop() { void GreedySearch_Cpu::SelectTop() { // next_tokens = torch.argmax(scores, dim=-1) - for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) { + for (size_t batch_id = 0; batch_id < params_->search.batch_size; batch_id++) { if (PadIfAlreadyEOS(batch_id)) { continue; } @@ -148,7 +148,7 @@ void GreedySearch_Cpu::SelectTop() { } void GreedySearch_Cpu::SampleTopK(int k, float temperature) { - for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) { + for (size_t batch_id = 0; batch_id < params_->search.batch_size; batch_id++) { std::span const scores = next_token_scores_.subspan(batch_id * params_->config.model.vocab_size, params_->config.model.vocab_size); SoftMax(scores, temperature); // Find the top K scores @@ -164,7 +164,7 @@ void GreedySearch_Cpu::SampleTopK(int k, float temperature) { void GreedySearch_Cpu::SampleTopP(float p, float temperature) { std::uniform_real_distribution dis(0, p); - for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) { + for (size_t batch_id = 0; batch_id < params_->search.batch_size; batch_id++) { if (PadIfAlreadyEOS(batch_id)) { continue; } @@ -193,7 +193,7 @@ void GreedySearch_Cpu::SampleTopP(float p, float temperature) { void GreedySearch_Cpu::SampleTopKTopP(int k, float p, float temperature) { std::uniform_real_distribution dis(0, p); - for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) { + for (size_t batch_id = 0; batch_id < params_->search.batch_size; batch_id++) { if (PadIfAlreadyEOS(batch_id)) { continue; } @@ -230,9 +230,9 @@ bool GreedySearch_Cpu::PadIfAlreadyEOS(size_t batch_id) { return true; } -void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token) { +void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token, bool check_eos) { next_tokens_[batch_id] = token; - if (token == params_->config.model.eos_token_id) { + if (check_eos && token == params_->config.model.eos_token_id) { eos_seen_[batch_id] = true; if 
(g_log.enabled && g_log.hit_eos) Log("hit_eos", "EOS seen on batch " + std::to_string(batch_id)); @@ -252,6 +252,57 @@ void GreedySearch_Cpu::AppendNextTokensToSequences() { } } +void GreedySearch_Cpu::SetUserTokens(RoamingArray next_tokens) { + // Reset done count/state + done_ = false; + not_done_count_ = params_->search.batch_size; + memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + + // Set user-defined next tokens + auto next_tokens_cpu = next_tokens.GetCPU(); + auto batch_size = params_->search.batch_size; + auto tokens_count_per_batch = next_tokens_cpu.size() / batch_size; + for (size_t j = 0; j < tokens_count_per_batch; j++) { + for (size_t i = 0; i < batch_size; i++) { + SetNextToken(i, next_tokens_cpu[i * tokens_count_per_batch + j], false); + } + AppendNextTokensToSequences(); + } +} + +void GreedySearch_Cpu::RewindTo(size_t index) { + sequences_.RewindTo(index); + done_ = false; + not_done_count_ = params_->search.batch_size; + memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + // Set next tokens to the last tokens in the sequence + if (index > 0) { + sequences_.GetLastTokens(next_tokens_); + } + else + memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); +} + +void GreedySearch_Cpu::DropLastTokens(size_t num_tokens) { + auto sequences_cpu = sequences_.GetSequences(); + auto new_sequence_length = sequences_.GetSequenceLength() - num_tokens; + for (size_t i = 0; i < params_->search.batch_size; ++i) { + if (!eos_seen_[i]) + continue; + auto sequence_cpu = sequences_cpu.subspan(i * params_->search.max_length + new_sequence_length, num_tokens); + for (size_t j = 0; j < num_tokens; ++j) { + if (sequence_cpu[j] == params_->config.model.eos_token_id) { + not_done_count_++; + done_ = false; + eos_seen_[i] = false; + if (g_log.enabled && g_log.hit_eos) + Log("hit_eos", "Reverted EOS seen on batch " + std::to_string(i)); + } + } + } + sequences_.DropLastTokens({num_tokens}); +} + bool BeamSearch_Cpu::IsDone() const { if (beam_scorer_->IsDone()) { return true; diff --git a/src/search.h b/src/search.h index 9ca2bcc99..af57fb0e9 100644 --- a/src/search.h +++ b/src/search.h @@ -27,6 +27,13 @@ struct Search : LeakChecked { virtual void ApplyMinLength(int min_length) = 0; virtual void ApplyRepetitionPenalty(float penalty) = 0; + // Set user input tokens + virtual void SetUserTokens(RoamingArray next_tokens) { assert(false); }; + // To be used for rewind + virtual void RewindTo(size_t index) { assert(false); }; + // To be used for rewind + virtual void DropLastTokens(size_t num_tokens) { assert(false); }; + std::shared_ptr params_; }; @@ -68,17 +75,24 @@ struct GreedySearch_Cpu : Search_Cpu { void SampleTopP(float p, float temperature) override; void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override; + // Used by continuous decoding search. 
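  // (Editorial annotation, not part of the diff) SetUserTokens appends caller-supplied tokens to
  // every row after resetting the done/EOS bookkeeping; RewindTo(index) truncates the sequences
  // back to index tokens and reseeds next_tokens_ from the last kept token (or zeroes it when
  // index == 0); DropLastTokens removes the newest tokens and, for rows whose dropped span
  // contained an EOS, reverts the done/EOS state.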
+ void SetUserTokens(RoamingArray next_tokens) override; + void RewindTo(size_t index) override; + void DropLastTokens(size_t num_tokens) override; + + protected: + void SetNextToken(size_t batch_id, int32_t token, bool check_eos = true); + void AppendNextTokensToSequences(); + private: bool PadIfAlreadyEOS(size_t batch_id); - void SetNextToken(size_t batch_id, int32_t token); - void AppendNextTokensToSequences(); std::unique_ptr next_tokens_buffer_; std::unique_ptr temp_topk_buffer_; std::span eos_seen_; // shape (batch_size) std::unique_ptr eos_seen_buffer_; - int not_done_count_{params_->batch_size}; // When zero, every batch entry is done (starts at batch_size_) + int not_done_count_{params_->search.batch_size}; // When zero, every batch entry is done (starts at batch_size_) std::mt19937 gen_; }; diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp index 23a232bb3..d9c6a6c72 100644 --- a/src/search_cuda.cpp +++ b/src/search_cuda.cpp @@ -17,7 +17,7 @@ void OnCudaError(cudaError_t error) { Search_Cuda::Search_Cuda(const GeneratorParams& params) : Search{params}, - sequences_{params.input_ids, params.batch_size, params.search.num_beams, params_->search.max_length, params_->cuda_stream} { + sequences_{params.search.batch_size, params.search.num_beams, params_->search.max_length, params_->cuda_stream} { auto batch_beam_size = params.BatchBeamSize(); sequence_lengths_buffer_ = std::make_unique(batch_beam_size); sequence_lengths_ = cpu_span(sequence_lengths_buffer_.get(), batch_beam_size); @@ -31,7 +31,7 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params) : Search_Cuda{params} { - next_tokens_buffer_ = CudaMallocArray(params.batch_size, &next_tokens_); + next_tokens_buffer_ = CudaMallocArray(params.search.batch_size, &next_tokens_); cudaMemsetAsync(next_tokens_.data(), 0, next_tokens_.size_bytes(), params_->cuda_stream); unsigned long long random_seed; @@ -39,7 +39,7 @@ GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params) random_seed = params_->search.random_seed; else random_seed = std::random_device{}(); - samplingdata_ = std::make_unique(random_seed, params_->batch_size, params_->config.model.vocab_size, params_->cuda_stream); + samplingdata_ = std::make_unique(random_seed, params_->search.batch_size, params_->config.model.vocab_size, params_->cuda_stream); } BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) @@ -97,7 +97,7 @@ void BeamSearch_Cuda::SelectTop() { // Add beam score to next token scores. 
Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) cuda::LaunchAddProbsKernel(softmax_buffer_.get(), beam_scores.data(), - params_->batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->cuda_stream); + params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->cuda_stream); if (params_->search.num_beams <= 32) { constexpr size_t max_parts_of_vocab = 128; @@ -109,7 +109,7 @@ void BeamSearch_Cuda::SelectTop() { int32_t* topk_tokens_2nd_stage = reinterpret_cast(topk_scores_2nd_stage + candidate_count); cuda::BeamSearchTopK(softmax_buffer_.get(), - params_->batch_size, + params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, 2 * params_->search.num_beams, @@ -145,40 +145,41 @@ void BeamSearch_Cuda::SelectTop() { } void GreedySearch_Cuda::SelectTop() { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); - cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), - params_->batch_size, 1, 0.0, 1.0); + std::span scores = next_token_scores_.subspan(0, params_->search.batch_size * params_->config.model.vocab_size); + cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size), + params_->search.batch_size, 1, 0.0, 1.0); CheckForEOS(); AppendNextTokensToSequences(); } void GreedySearch_Cuda::SampleTopP(float p, float temperature) { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); - cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), - params_->batch_size, -1, p, temperature); + std::span scores = next_token_scores_.subspan(0, params_->search.batch_size * params_->config.model.vocab_size); + cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size), + params_->search.batch_size, -1, p, temperature); CheckForEOS(); AppendNextTokensToSequences(); } void GreedySearch_Cuda::SampleTopK(int k, float temperature) { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); - cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), - params_->batch_size, k, 0.0, temperature); + std::span scores = next_token_scores_.subspan(0, params_->search.batch_size * params_->config.model.vocab_size); + cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size), + params_->search.batch_size, k, 0.0, temperature); CheckForEOS(); AppendNextTokensToSequences(); } void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { - std::span scores = next_token_scores_.subspan(0, params_->batch_size * params_->config.model.vocab_size); - cuda::GetSample(samplingdata_.get(), params_->cuda_stream, next_tokens_.data(), scores.data(), int(scores.size() / params_->batch_size), - params_->batch_size, k, p, temperature); + std::span scores = next_token_scores_.subspan(0, params_->search.batch_size * params_->config.model.vocab_size); + cuda::GetSample(samplingdata_.get(), params_->cuda_stream, 
next_tokens_.data(), scores.data(), int(scores.size() / params_->search.batch_size), + params_->search.batch_size, k, p, temperature); CheckForEOS(); AppendNextTokensToSequences(); } void GreedySearch_Cuda::CheckForEOS() { assert(next_tokens_.size() == eos_meet_.size()); - cuda::Launch_CheckForEOS(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->config.model.pad_token_id, done_cpu_.get(), params_->cuda_stream); + // Don't replace EOS with pad for batch_size == 1 for continuous decoding mode + cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->config.model.eos_token_id, params_->search.batch_size > 1 ? params_->config.model.pad_token_id : params_->config.model.eos_token_id, done_cpu_.get(), params_->cuda_stream); } void GreedySearch_Cuda::AppendNextTokensToSequences() { @@ -234,7 +235,7 @@ void GreedySearch::Finalize(size_t num_return_sequences, std::span outp // Copy the sequences to output std::span output{ output_sequences_->GetTensorMutableData(), shape_count}; - for (int batch_id = 0; batch_id < params_->batch_size; ++batch_id) { + for (int batch_id = 0; batch_id < params_->search.batch_size; ++batch_id) { auto batch_output = output.subspan( static_cast(batch_id) * params_->max_length, params_->max_length); @@ -253,6 +254,34 @@ std::span Search_Cuda::GetScores() { return next_token_scores_; } +// Set user input tokens (batch_beam_size, sequence_length) +void GreedySearch_Cuda::SetUserTokens(RoamingArray next_tokens) { + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + *done_cpu_ = false; + + auto next_tokens_gpu = next_tokens.GetGPU(); + auto batch_size = params_->search.batch_size; + auto tokens_count_per_batch = next_tokens_gpu.size() / batch_size; + sequences_.AppendUserTokensToSequences(next_tokens_gpu); + + if (sequences_.GetSequenceLength() == params_->search.max_length) { + if (g_log.enabled && g_log.hit_max_length) + Log("hit_max_length", "greedy cuda hit"); + *done_cpu_ = true; + } +} + +void GreedySearch_Cuda::RewindTo(size_t index) { + sequences_.RewindTo(index); + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + *done_cpu_ = false; + if (index > 0) { + sequences_.GetLastTokens(next_tokens_); + } + else + cudaMemsetAsync(next_tokens_.data(), 0, params_->search.batch_size * sizeof(int32_t), params_->cuda_stream); +} + void Search_Cuda::ApplyMinLength(int min_length) { if (sequences_.GetSequenceLength() >= min_length) return; @@ -265,7 +294,7 @@ void Search_Cuda::ApplyRepetitionPenalty(float penalty) { return; cuda::LaunchRepetitionPenaltyProcessor(sequences_.GetSequences().data(), - GetScores().data(), params_->batch_size, params_->search.num_beams, params_->config.model.vocab_size, + GetScores().data(), params_->search.batch_size, params_->search.num_beams, params_->config.model.vocab_size, params_->search.max_length, GetSequenceLength(), penalty, params_->cuda_stream); } diff --git a/src/search_cuda.cu b/src/search_cuda.cu index 24fbaf914..0ea68c369 100644 --- a/src/search_cuda.cu +++ b/src/search_cuda.cu @@ -38,7 +38,7 @@ struct ArgMaxDataImpl : ArgMaxData { cuda_unique_ptr> argmaxen_owner_; }; -__global__ void CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu) { +__global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* 
done_cpu) { // Look for EOS tokens, if seen set EOS flag and replace with pad token for (size_t batch_id = 0; batch_id < next_tokens_count; ++batch_id) { if (next_tokens[batch_id] == eos_token_id || eos_meet[batch_id] == true) { @@ -64,8 +64,8 @@ __global__ void CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* e } } -void Launch_CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream) { - CheckForEOS<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_meet, eos_token_id, pad_token_id, done_cpu); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream) { + CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_meet, eos_token_id, pad_token_id, done_cpu); } __global__ void AddProbsKernel(float* log_probs, diff --git a/src/search_cuda.cuh b/src/search_cuda.cuh index d3662ff17..a6db5e500 100644 --- a/src/search_cuda.cuh +++ b/src/search_cuda.cuh @@ -6,7 +6,7 @@ struct ArgMaxData { virtual ~ArgMaxData() = default; }; -void Launch_CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream); void LaunchAddProbsKernel(float* log_probs, float* cum_log_probs, const int batch_size, const int num_beams, const int vocab_size, cudaStream_t stream); void LaunchSetScoreProcessor(float* next_token_scores, int batch_beam_size, int vocab_size, int token, float score, cudaStream_t stream); void LaunchRepetitionPenaltyProcessor(const int32_t* sequences, float* next_token_scores, int batch_size, int num_beams, int vocab_size, int max_sequence_length, int current_sequence_length, float repetition_penalty, cudaStream_t stream); diff --git a/src/search_cuda.h b/src/search_cuda.h index 8a699b880..bcccb14ca 100644 --- a/src/search_cuda.h +++ b/src/search_cuda.h @@ -52,6 +52,8 @@ struct GreedySearch_Cuda : Search_Cuda { void SampleTopK(int k, float t) override; void SampleTopP(float p, float t) override; void SampleTopKTopP(int k, float p, float t) override; + void SetUserTokens(RoamingArray next_tokens) override; // shape (batch_size, sequence_length) + void RewindTo(size_t index) override; private: void CheckForEOS(); diff --git a/src/sequences.cpp b/src/sequences.cpp index 39354a56c..b436bc3c6 100644 --- a/src/sequences.cpp +++ b/src/sequences.cpp @@ -6,11 +6,10 @@ namespace Generators { -Sequences::Sequences(std::span input_sequences, int batch_size, int beam_size, int max_length) +Sequences::Sequences(int batch_size, int beam_size, int max_length) : batch_beam_size_{batch_size * beam_size}, max_length_{max_length}, - current_length_{static_cast(input_sequences.size()) / batch_size} { - assert(current_length_ * batch_size == input_sequences.size()); // Ensure size divided perfectly + current_length_{0} { const size_t sequences_size = static_cast(batch_beam_size_) * max_length; if (beam_size == 1) { @@ -21,16 +20,6 @@ Sequences::Sequences(std::span input_sequences, int batch_size, i sequences_ = cpu_span(sequences_buffer_.get(), sequences_size); sequences_next_ = cpu_span(sequences_buffer_.get() + sequences_size, sequences_size); } - - // The original inputs are not expanded, this expands them in place into the sequences - for (size_t 
batch = 0; batch < batch_size; batch++) { - for (size_t beam = 0; beam < beam_size; beam++) { - for (int j = 0; j < current_length_; j++) { - sequences_[(batch * beam_size + beam) * max_length + j] = - static_cast(input_sequences[batch * current_length_ + j]); - } - } - } } cpu_span Sequences::GetSequence(size_t batch_beam_index) { @@ -73,4 +62,20 @@ void Sequences::AppendNextTokenToSequences(std::span next_tokens) ++current_length_; } +void Sequences::GetLastTokens(cpu_span& last_tokens) { + for (int i = 0; i < batch_beam_size_; i++) { + last_tokens[i] = sequences_[i * max_length_ + current_length_ - 1]; + } +} + +void Sequences::RewindTo(size_t index) { + current_length_ = static_cast(index); + assert(current_length_ >= 0); +} + +void Sequences::DropLastTokens(size_t num_tokens) { + current_length_ -= static_cast(num_tokens); + assert(current_length_ >= 0); +} + } // namespace Generators diff --git a/src/sequences.h b/src/sequences.h index 5407a3bc1..86ca444ba 100644 --- a/src/sequences.h +++ b/src/sequences.h @@ -3,7 +3,7 @@ namespace Generators { // This class keeps track of sequences generated. struct Sequences { - Sequences(std::span input_sequence, int batch_size, int beam_size, int max_length); + Sequences(int batch_size, int beam_size, int max_length); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). cpu_span GetSequence(size_t batch_beam_index); @@ -19,6 +19,13 @@ struct Sequences { // Used by Greedy search: void AppendNextTokenToSequences(std::span next_tokens); + // Return Token IDs of last token in each sequence + void GetLastTokens(cpu_span& last_tokens); + // Rewind sequences to ith token + void RewindTo(size_t index); + // TODO(aciddelgado): To be used for rewind + void DropLastTokens(size_t num_tokens); + private: std::unique_ptr sequences_buffer_; diff --git a/src/sequences_cuda.cpp b/src/sequences_cuda.cpp index 8f7a643e5..ef00da7b4 100644 --- a/src/sequences_cuda.cpp +++ b/src/sequences_cuda.cpp @@ -8,14 +8,15 @@ namespace Generators { namespace cuda { void Launch_ExpandInputSequences(std::span input_sequences, std::span sequences, int batch_size, int beam_size, int current_length, int max_length, cudaStream_t stream); void Launch_AppendNextTokenToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int current_length, int max_length, cudaStream_t stream); +void Launch_AppendUserTokensToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int past_length, int new_length, int max_length, cudaStream_t stream); +void Launch_GetLastTokens(std::span sequences, std::span last_tokens, int batch_beam_size, int current_length, int max_length, cudaStream_t stream); } // namespace cuda -Sequences_Cuda::Sequences_Cuda(std::span input_sequences, int batch_size, int beam_size, int max_length, cudaStream_t stream) +Sequences_Cuda::Sequences_Cuda(int batch_size, int beam_size, int max_length, cudaStream_t stream) : stream_{stream}, batch_beam_size_{batch_size * beam_size}, max_length_{max_length}, - current_length_{static_cast(input_sequences.size()) / batch_size} { - assert(current_length_ * batch_size == input_sequences.size()); // Ensure size divided perfectly + current_length_{0} { size_t sequences_size = batch_beam_size_ * max_length; if (beam_size == 1) { @@ -30,10 +31,7 @@ Sequences_Cuda::Sequences_Cuda(std::span input_sequences, int bat // TODO: input_sequences will be in cuda memory in the future, for now make a temp copy gpu_span input_sequences_gpu; - auto input_sequences_temp = 
CudaMallocArray(input_sequences.size(), &input_sequences_gpu); - cudaMemcpyAsync(input_sequences_gpu.data(), input_sequences.data(), input_sequences.size_bytes(), cudaMemcpyHostToDevice, stream); - cuda::Launch_ExpandInputSequences(input_sequences_gpu, sequences_, batch_size, beam_size, current_length_, max_length, stream_); cudaStreamSynchronize(stream); // Until we remove the todo above, wait for this to complete as input_sequences_gpu is on the stack } @@ -57,6 +55,23 @@ void Sequences_Cuda::AppendNextTokenToSequences(std::span next_to ++current_length_; } +void Sequences_Cuda::AppendUserTokensToSequences(gpu_span user_tokens) { + size_t new_length = user_tokens.size() / batch_beam_size_; + size_t past_length = current_length_; + cuda::Launch_AppendUserTokensToSequences(user_tokens, sequences_, batch_beam_size_, past_length, new_length, max_length_, stream_); + current_length_ += new_length; +} + +void Sequences_Cuda::RewindTo(size_t index) { + current_length_ = index; + assert(current_length_ >= 0); +} + +void Sequences_Cuda::GetLastTokens(gpu_span& last_tokens) { + // TODO(aciddelgado): throw error when no last tokens + cuda::Launch_GetLastTokens(sequences_, last_tokens, batch_beam_size_, current_length_, max_length_, stream_); +} + void Sequences_Cuda::AfterDeviceAppendedNextToken() { ++current_length_; diff --git a/src/sequences_cuda.cu b/src/sequences_cuda.cu index dfbc8aae1..65e2d6e4b 100644 --- a/src/sequences_cuda.cu +++ b/src/sequences_cuda.cu @@ -32,5 +32,32 @@ void Launch_AppendNextTokenToSequences(std::span next_tokens, std AppendNextTokenToSequences<<<1, 1, 0, stream>>>(next_tokens.data(), sequences.data(), batch_beam_size, current_length, max_length); } +// TODO(aciddelgado): parallelize this kernel. +__global__ void AppendUserTokensToSequences(const int32_t* user_tokens, int32_t* sequences, int batch_beam_size, int past_length, int new_length, int max_length) { + // Append user tokens to each sequence. + for (int i = 0; i < batch_beam_size; i++) { + for (int j = 0; j < new_length; j++) { + sequences[i * max_length + past_length + j] = user_tokens[i * new_length + j]; + } + } +} + +void Launch_AppendUserTokensToSequences(std::span user_tokens, std::span sequences, int batch_beam_size, int past_length, int new_length, int max_length, cudaStream_t stream) { + AppendUserTokensToSequences<<<1, 1, 0, stream>>>(user_tokens.data(), sequences.data(), batch_beam_size, past_length, new_length, max_length); +} + +// TODO(aciddelgado): parallelize this kernel. +__global__ void GetLastTokens(const int32_t* sequences, int32_t* last_tokens, int batch_beam_size, int current_length, int max_length) { + // Get the last token of each sequence. + for (int i = 0; i < batch_beam_size; i++) { + last_tokens[i] = sequences[i * max_length + current_length - 1]; + } +} + +void Launch_GetLastTokens(std::span sequences, std::span last_tokens, int batch_beam_size, int current_length, int max_length, cudaStream_t stream) { + // Get the last token of each sequence. + GetLastTokens<<<1, 1, 0, stream>>>(sequences.data(), last_tokens.data(), batch_beam_size, current_length, max_length); +} + } // namespace cuda } // namespace Generators diff --git a/src/sequences_cuda.h b/src/sequences_cuda.h index 8dc9038c3..dd1d14da9 100644 --- a/src/sequences_cuda.h +++ b/src/sequences_cuda.h @@ -3,7 +3,7 @@ namespace Generators { // This class keeps track of sequences generated. 
struct Sequences_Cuda { - Sequences_Cuda(std::span input_sequences, int batch_size, int beam_size, int max_length, cudaStream_t stream); + Sequences_Cuda(int batch_size, int beam_size, int max_length, cudaStream_t stream); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). RoamingArray GetSequence(size_t batch_beam_index); @@ -11,6 +11,10 @@ struct Sequences_Cuda { gpu_span GetNextSequences() { return sequences_next_; } void AppendNextTokenToSequences(std::span next_tokens); + void AppendUserTokensToSequences(gpu_span user_tokens); + + void GetLastTokens(gpu_span& last_tokens); + void RewindTo(size_t index); // Returns current sequence length. int GetSequenceLength() const; diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index e6d9c2a65..7d19a6912 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -102,35 +102,6 @@ TEST(CAPITests, AppendTokensToSequence) { #endif } -TEST(CAPITests, EndToEndPhiBatch) { -#if TEST_PHI2 - auto model = OgaModel::Create(MODEL_PATH "phi-2"); - auto tokenizer = OgaTokenizer::Create(*model); - - const char* input_strings[] = { - "This is a test.", - "Rats are awesome pets!", - "The quick brown fox jumps over the lazy dog.", - }; - - auto input_sequences = OgaSequences::Create(); - for (auto& string : input_strings) - tokenizer->Encode(string, *input_sequences); - - auto params = OgaGeneratorParams::Create(*model); - params->SetSearchOption("max_length", 20); - params->SetInputSequences(*input_sequences); - - auto output_sequences = model->Generate(*params); - - // Decode The Batch - for (size_t i = 0; i < output_sequences->Count(); i++) { - auto out_string = tokenizer->Decode(output_sequences->Get(i)); - std::cout << "Decoded string:" << out_string << std::endl; - } -#endif -} - TEST(CAPITests, Tensor_And_AddExtraInput) { // Create a [3 4] shaped tensor std::array data{0, 1, 2, 3, @@ -176,15 +147,13 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2) auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); - auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", max_length); - params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); + params->SetSearchOption("batch_size", batch_size); auto generator = OgaGenerator::Create(*model, *params); - + generator->AddInputTokens(input_ids.data(), input_ids.size()); while (!generator->IsDone()) { - generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -198,20 +167,6 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { const auto* expected_output_start = &expected_output[i * max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } - - // Test high level API - auto sequences = model->Generate(*params); - - // Verify outputs match expected outputs - for (int i = 0; i < batch_size; i++) { - const auto sequence_length = sequences->SequenceCount(i); - const auto* sequence_data = sequences->SequenceData(i); - - ASSERT_LE(sequence_length, max_length); - - const auto* expected_output_start = &expected_output[i * max_length]; - EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); - } } #endif @@ -231,9 +186,9 @@ TEST(CAPITests, GetOutputCAPI) { auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", max_length); - 
params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); auto generator = OgaGenerator::Create(*model, *params); + generator->AddInputTokens(input_ids.data(), input_ids.size()); // check prompt // full logits has shape [2, 4, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 4, 5] @@ -246,7 +201,6 @@ TEST(CAPITests, GetOutputCAPI) { -0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f, 0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f}; - generator->ComputeLogits(); auto prompt_logits_ptr = generator->GetOutput("logits"); auto prompt_logits = static_cast(prompt_logits_ptr->Data()); int num_prompt_outputs_to_check = 40; @@ -257,13 +211,13 @@ TEST(CAPITests, GetOutputCAPI) { EXPECT_NEAR(expected_sampled_logits_prompt[i], prompt_logits[i*sample_size], tolerance); } + generator->GenerateNextToken(); generator->GenerateNextToken(); // check for the 1st token generation // full logits has shape [2, 1, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 1, 5] std::vector expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f, 0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f}; - generator->ComputeLogits(); auto token_gen_logits_ptr = generator->GetOutput("logits"); auto token_gen_logits = static_cast(token_gen_logits_ptr->Data()); int num_token_gen_outputs_to_check = 10; @@ -271,7 +225,6 @@ TEST(CAPITests, GetOutputCAPI) { for (int i = 0; i < num_token_gen_outputs_to_check; i++) { EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance); } - generator->GenerateNextToken(); } #if TEST_PHI2 @@ -293,7 +246,6 @@ struct Phi2Test { tokenizer_->Encode(string, *input_sequences_); params_ = OgaGeneratorParams::Create(*model_); - params_->SetInputSequences(*input_sequences_); params_->SetSearchOption("max_length", 40); } @@ -301,9 +253,9 @@ struct Phi2Test { // Low level loop { auto generator = OgaGenerator::Create(*model_, *params_); + generator->AddInputSequences(input_sequences_); while (!generator->IsDone()) { - generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -313,17 +265,6 @@ struct Phi2Test { std::cout << "Decoded string:" << out_string << std::endl; } } - - // High level - { - auto output_sequences = model_->Generate(*params_); - - // Decode The Batch - for (size_t i = 0; i < output_sequences->Count(); i++) { - auto out_string = tokenizer_->Decode(output_sequences->Get(i)); - std::cout << "Decoded string:" << out_string << std::endl; - } - } } std::unique_ptr model_; diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 6766fb892..88b843862 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifndef MODEL_PATH #define MODEL_PATH "../../test/test_models/" #endif @@ -37,19 +38,17 @@ TEST(ModelTests, GreedySearchGptFp32) { auto params = Generators::CreateGeneratorParams(*model); params->search.max_length = 10; - params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); - params->input_ids = input_ids; + params->search.batch_size = static_cast(input_ids_shape[0]); auto generator = Generators::CreateGenerator(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); while (!generator->IsDone()) { - generator->ComputeLogits(); generator->GenerateNextToken(); } // Verify outputs match expected outputs 
- for (size_t i = 0; i < static_cast(params->batch_size); i++) { + for (size_t i = 0; i < static_cast(params->search.batch_size); i++) { auto sequence = generator->GetSequence(i).GetCPU(); auto* expected_output_start = &expected_output[i * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); @@ -76,20 +75,21 @@ TEST(ModelTests, BeamSearchGptFp32) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); auto params = Generators::CreateGeneratorParams(*model); - params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); - params->input_ids = input_ids; + params->search.batch_size = static_cast(input_ids_shape[0]); params->search.max_length = 20; params->search.length_penalty = 1.0f; params->search.num_beams = 4; auto generator = Generators::CreateGenerator(*model, *params); - auto result = Generators::Generate(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); + while (!generator->IsDone()) { + generator->GenerateNextToken(); + } // Verify outputs match expected outputs - for (int i = 0; i < params->batch_size; i++) { - auto sequence = std::span(result[i].data(), params->search.max_length); + for (int i = 0; i < params->search.batch_size; i++) { + auto sequence = generator->GetSequence(i).GetCPU(); auto* expected_output_start = &expected_output[static_cast(i) * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); } @@ -109,20 +109,18 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) auto model = Generators::CreateModel(Generators::GetOrtEnv(), model_path); auto params = Generators::CreateGeneratorParams(*model); - params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); + params->search.batch_size = static_cast(input_ids_shape[0]); params->search.max_length = 10; - params->input_ids = input_ids; auto generator = Generators::CreateGenerator(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); while (!generator->IsDone()) { - generator->ComputeLogits(); generator->GenerateNextToken(); } // Verify outputs match expected outputs - for (int i = 0; i < params->batch_size; i++) { + for (int i = 0; i < params->search.batch_size; i++) { auto sequence_gpu = generator->GetSequence(i); auto sequence = sequence_gpu.GetCPU(); auto* expected_output_start = &expected_output[i * params->search.max_length]; @@ -154,20 +152,21 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { auto model = Generators::CreateModel(Generators::GetOrtEnv(), model_path); auto params = Generators::CreateGeneratorParams(*model); - params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); - params->input_ids = input_ids; + params->search.batch_size = static_cast(input_ids_shape[0]); params->search.max_length = 20; params->search.num_beams = 4; params->search.length_penalty = 1.0f; auto generator = Generators::CreateGenerator(*model, *params); - auto result = Generators::Generate(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); + while (!generator->IsDone()) { + generator->GenerateNextToken(); + } // Verify outputs 
match expected outputs - for (int i = 0; i < params->batch_size; i++) { - auto sequence = std::span(result[i].data(), params->search.max_length); + for (int i = 0; i < params->search.batch_size; i++) { + auto sequence = generator->GetSequence(i).GetCPU(); auto* expected_output_start = &expected_output[static_cast(i) * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); } @@ -195,15 +194,13 @@ Print all primes between 1 and n auto tokens = tokenizer->Encode(prompt); auto params = Generators::CreateGeneratorParams(*model); - params->batch_size = 1; - params->sequence_length = static_cast(tokens.size()); - params->input_ids = tokens; + params->search.batch_size = 1; params->search.max_length = 128; // Generator version auto generator = Generators::CreateGenerator(*model, *params); + generator->AddInputTokens(Generators::cpu_span(tokens.data(), tokens.size())); while (!generator->IsDone()) { - generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -213,32 +210,4 @@ Print all primes between 1 and n #endif } -TEST(ModelTests, TestHighLevelApiCuda) { -#if TEST_PHI2 - auto prompt = R"( -def print_prime(n): -''' -Print all primes between 1 and n -''' -)"; - - std::cout << "With prompt:" << prompt << "\r\n"; - - auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "phi-2"); - auto tokenizer = model->CreateTokenizer(); - auto tokens = tokenizer->Encode(prompt); - - auto params = Generators::CreateGeneratorParams(*model); - params->batch_size = 1; - params->sequence_length = static_cast(tokens.size()); - params->input_ids = tokens; - params->search.max_length = 128; - - // High level version - auto result = Generators::Generate(*model, *params); - - std::cout << tokenizer->Decode(result[0]) << "\r\n"; -#endif -} - #endif \ No newline at end of file diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index 59cfa61c0..5bbf4553b 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -27,9 +27,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCpu) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; @@ -64,9 +62,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCpu) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; @@ -104,9 +100,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCpu) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; @@ -145,9 +139,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCuda) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - 
params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; std::vector cpu_logits(config.model.vocab_size * batch_size); std::random_device rd; @@ -191,9 +183,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); @@ -234,9 +224,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); @@ -279,9 +267,7 @@ TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(config.model.vocab_size * batch_size); diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 7a6505255..281286abd 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -29,9 +29,7 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.25f; - params->batch_size = 4; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = 4; params->device_type = Generators::DeviceType::CUDA; auto generator = Generators::CreateGenerator(*model, *params); auto logits_span = Generators::cpu_span(logits_cpu); @@ -58,9 +56,7 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = 2; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); auto logits_copy = logits_cpu; @@ -94,9 +90,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { params->search.do_sample = true; params->search.top_k = 2; params->search.top_p = 0.25f; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); auto logits_copy = logits_cpu; @@ -146,9 +140,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.95f; - params->batch_size = 
batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; @@ -186,9 +178,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = k; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; @@ -228,9 +218,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { params->search.do_sample = true; params->search.top_k = k; params->search.top_p = p; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(config.model.vocab_size * batch_size); std::random_device rd; @@ -277,9 +265,7 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.25f; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -309,9 +295,7 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = 2; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -346,9 +330,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { params->search.do_sample = true; params->search.top_k = 2; params->search.top_p = 0.25f; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -377,9 +359,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_p = 0.95f; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); @@ -422,9 +402,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { params->search.max_length = 10; params->search.do_sample = true; params->search.top_k = k; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = 
batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); @@ -468,9 +446,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { params->search.do_sample = true; params->search.top_k = k; params->search.top_p = p; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size); @@ -509,9 +485,7 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) { auto params = Generators::CreateGeneratorParams(config); params->search.max_length = 10; - params->batch_size = batch_size; - params->sequence_length = 1; - params->input_ids = input_ids; + params->search.batch_size = batch_size; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(config.model.vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(config.model.vocab_size * batch_size);
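For reference, below is a minimal sketch of the decode loop that the updated examples, benchmarks, and tests converge on, using the C++ OgaGenerator API. The model path, prompt text, and max_length value are illustrative placeholders, and the header and function names follow the existing C++ examples rather than anything introduced by this change; treat it as a usage sketch, not as part of the diff.

```cpp
// Minimal usage sketch of the incremental decoding flow (assumed model path,
// prompt, and max_length; not taken from this diff).
#include <iostream>
#include "ort_genai.h"

void RunDecodeLoop(const char* model_path) {
  auto model = OgaModel::Create(model_path);
  auto tokenizer = OgaTokenizer::Create(*model);

  // Tokenize a placeholder prompt into an OgaSequences container.
  auto sequences = OgaSequences::Create();
  tokenizer->Encode("Hello", *sequences);

  auto params = OgaGeneratorParams::Create(*model);
  params->SetSearchOption("max_length", 128);

  // Prompt tokens are appended to the generator, not set on the params,
  // and each GenerateNextToken() call runs the forward pass itself, so
  // there is no separate ComputeLogits() step.
  auto generator = OgaGenerator::Create(*model, *params);
  generator->AddInputSequences(*sequences);
  while (!generator->IsDone()) {
    generator->GenerateNextToken();
  }

  // Read back and decode the completed sequence for batch entry 0.
  const auto output = tokenizer->Decode(generator->GetSequenceData(0),
                                        generator->GetSequenceCount(0));
  std::cout << output << std::endl;
}
```

The same pattern appears in the native tests with token arrays instead of sequences (AddInputTokens / AddTokens plus a batch_size search option), and in the Python example via generator.add_input_tokens() followed by the is_done() / generate_next_token() loop.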