Aciddelgado/continuous #867

Draft · wants to merge 15 commits into base: main
30 changes: 16 additions & 14 deletions benchmark/c/main.cpp
@@ -121,11 +121,15 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons
auto params = OgaGeneratorParams::Create(model);
params->SetSearchOption("max_length", static_cast<double>(num_prompt_tokens));
params->SetSearchOption("min_length", static_cast<double>(num_prompt_tokens));
params->SetInputSequences(*base_prompt_sequences);

auto output_sequences = model.Generate(*params);
const auto output_sequence_length = output_sequences->SequenceCount(0);
const auto* output_sequence_data = output_sequences->SequenceData(0);
auto generator = OgaGenerator::Create(model, *params);
generator->AddInputSequences(*base_prompt_sequences);
while (!generator->IsDone()) {
generator->GenerateNextToken();
}

const auto output_sequence_length = generator->GetSequenceCount(0);
const auto* output_sequence_data = generator->GetSequenceData(0);
return std::string{tokenizer.Decode(output_sequence_data, output_sequence_length)};
}

@@ -151,7 +155,6 @@ void RunBenchmark(const benchmark::Options& opts) {
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", static_cast<double>(num_tokens));
params->SetSearchOption("min_length", static_cast<double>(num_tokens));
params->SetInputSequences(*prompt_sequences);
return params;
};

@@ -160,13 +163,17 @@ void RunBenchmark(const benchmark::Options& opts) {
// warmup
if (opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n";
for (size_t i = 0; i < opts.num_warmup_iterations; ++i) {
auto output_sequences = model->Generate(*generator_params);
auto generator = OgaGenerator::Create(*model, *generator_params);
generator->AddInputSequences(*prompt_sequences);
while (!generator->IsDone()) {
generator->GenerateNextToken();
}

if (opts.verbose && i == 0) {
// show prompt and output on first iteration
std::cout << "Prompt:\n\t" << prompt << "\n";
const auto output_sequence_length = output_sequences->SequenceCount(0);
const auto* output_sequence_data = output_sequences->SequenceData(0);
const auto output_sequence_length = generator->GetSequenceCount(0);
const auto* output_sequence_data = generator->GetSequenceData(0);
const auto output = tokenizer->Decode(output_sequence_data, output_sequence_length);
std::cout << "Output:\n\t" << output << "\n";
}
@@ -188,7 +195,7 @@ void RunBenchmark(const benchmark::Options& opts) {

{
Timing prompt_processing_timing{prompt_processing_times};
generator->ComputeLogits();
generator->AddInputSequences(*prompt_sequences);
}

{
Expand All @@ -199,11 +206,6 @@ void RunBenchmark(const benchmark::Options& opts) {
while (!generator->IsDone()) {
{
Timing token_gen_timing{token_gen_times};
generator->ComputeLogits();
}

{
Timing sampling_timing{sampling_times};
generator->GenerateNextToken();
}
}
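Taken together, the benchmark changes above show the new decode pattern: prompt sequences are handed to the generator with AddInputSequences rather than attached to the params via SetInputSequences, and the loop calls only GenerateNextToken because the separate ComputeLogits step is gone. A minimal, self-contained sketch of that pattern follows; the header name ort_genai.h, the model directory, the prompt, and the max_length value are placeholders for illustration, not part of this PR.

#include <iostream>
#include <string>

#include "ort_genai.h"  // C++ wrapper over the GenAI C API (assumed header name)

int main() {
  // Hypothetical model directory; substitute a real exported model path.
  auto model = OgaModel::Create("phi3-mini-int4");
  auto tokenizer = OgaTokenizer::Create(*model);

  auto sequences = OgaSequences::Create();
  tokenizer->Encode("The first 4 digits of pi are", *sequences);

  auto params = OgaGeneratorParams::Create(*model);
  params->SetSearchOption("max_length", 128);

  // Prompt tokens now go to the generator, not the params.
  auto generator = OgaGenerator::Create(*model, *params);
  generator->AddInputSequences(*sequences);

  // One call per step; there is no separate ComputeLogits() anymore.
  while (!generator->IsDone()) {
    generator->GenerateNextToken();
  }

  const auto* output_data = generator->GetSequenceData(0);
  const auto output_length = generator->GetSequenceCount(0);
  std::cout << std::string{tokenizer->Decode(output_data, output_length)} << "\n";
  return 0;
}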
6 changes: 2 additions & 4 deletions examples/c/src/phi3.cpp
@@ -93,12 +93,11 @@ void CXX_API(const char* model_path) {
std::cout << "Generating response..." << std::endl;
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 1024);
params->SetInputSequences(*sequences);

auto generator = OgaGenerator::Create(*model, *params);
generator->AddInputSequences(*sequences);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

if (is_first_token) {
@@ -179,13 +178,12 @@ void C_API(const char* model_path) {
OgaGeneratorParams* params;
CheckResult(OgaCreateGeneratorParams(model, &params));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 1024));
CheckResult(OgaGeneratorParamsSetInputSequences(params, sequences));

OgaGenerator* generator;
CheckResult(OgaCreateGenerator(model, params, &generator));
CheckResult(OgaGenerator_AddInputSequences(generator, sequences));

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken(generator));

if (is_first_token) {
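The phi3.cpp example above keeps its token-streaming behaviour; only the loop body changes. A rough sketch of the streaming variant of the same loop, assuming the OgaTokenizerStream helper from the same C++ wrapper and a generator and tokenizer set up as in the earlier sketch:

auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

while (!generator->IsDone()) {
  generator->GenerateNextToken();

  // Decode and print only the newest token of sequence 0.
  const auto num_tokens = generator->GetSequenceCount(0);
  const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
  std::cout << tokenizer_stream->Decode(new_token) << std::flush;
}
std::cout << std::endl;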
2 changes: 0 additions & 2 deletions examples/c/src/phi3v.cpp
@@ -76,7 +76,6 @@ void CXX_API(const char* model_path) {
auto generator = OgaGenerator::Create(*model, *params);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

const auto num_tokens = generator->GetSequenceCount(0);
@@ -162,7 +161,6 @@ void C_API(const char* model_path) {
CheckResult(OgaCreateGenerator(model, params, &generator));

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken(generator));

const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0);
21 changes: 14 additions & 7 deletions examples/python/model-generate.py
@@ -12,9 +12,9 @@ def main(args):
if hasattr(args, 'prompts'):
prompts = args.prompts
else:
prompts = ["I like walking my cute dog",
"What is the best restaurant in town?",
"Hello, how are you today?"]
prompts = ["The first 4 digits of pi are",
"The square root of 2 is",
"The first 6 numbers of the Fibonacci sequence are",]

if args.chat_template:
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
@@ -28,6 +28,7 @@ def main(args):
params = og.GeneratorParams(model)

search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
search_options['batch_size'] = 3

if (args.verbose): print(f'Args: {args}')
if (args.verbose): print(f'Search options: {search_options}')
@@ -37,22 +38,28 @@ def main(args):
params.try_graph_capture_with_max_batch_size(len(prompts))
if args.batch_size_for_cuda_graph:
params.try_graph_capture_with_max_batch_size(args.batch_size_for_cuda_graph)
params.input_ids = input_tokens
if args.verbose: print("GeneratorParams created")

generator = og.Generator(model, params)
if args.verbose: print("Generator created")

generator.add_input_tokens(input_tokens)
if args.verbose: print("Input tokens added")

if args.verbose: print("Generating tokens ...\n")
start_time = time.time()
output_tokens = model.generate(params)
while not generator.is_done():
generator.generate_next_token()
run_time = time.time() - start_time

for i in range(len(prompts)):
print(f'Prompt #{i}: {prompts[i]}')
print()
print(tokenizer.decode(output_tokens[i]))
print(tokenizer.decode(generator.get_sequence(i)))
print()

print()
total_tokens = sum(len(x) for x in output_tokens)
total_tokens = sum(len(generator.get_sequence(i)) for i in range(len(prompts)))
print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}")
print()

29 changes: 19 additions & 10 deletions examples/python/model-qa.py
@@ -16,6 +16,7 @@ def main(args):
if args.verbose: print()

search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
search_options['batch_size'] = 1

if args.verbose: print(search_options)

@@ -24,6 +25,16 @@ def main(args):
print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
exit(1)

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)

# Set system prompt
system_prompt = args.system_prompt
system_tokens = tokenizer.encode(system_prompt)
generator.add_input_tokens(system_tokens)
system_prompt_length = len(system_tokens)

# Keep asking for input prompts in a loop
while True:
text = input("Input: ")
@@ -39,11 +50,8 @@ def main(args):
prompt = f'{args.chat_template.format(input=text)}'

input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
params.input_ids = input_tokens
generator = og.Generator(model, params)

generator.add_input_tokens(input_tokens)
if args.verbose: print("Generator created")

if args.verbose: print("Running generation loop ...")
@@ -56,7 +64,6 @@ def main(args):

try:
while not generator.is_done():
generator.compute_logits()
generator.generate_next_token()
if args.timings:
if first:
@@ -71,14 +78,14 @@ def main(args):
print()
print()

# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
del generator

if args.timings:
prompt_time = first_token_timestamp - started_timestamp
run_time = time.time() - first_token_timestamp
print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")


# Rewind the generator to the system prompt
if args.rewind:
generator.rewind_to_length(system_prompt_length)

if __name__ == "__main__":
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
@@ -93,5 +100,7 @@ def main(args):
parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}')
parser.add_argument('-s', '--system_prompt', type=str, default='You are a helpful assistant. You are friendly, courteous, and professional. All your responses must end with an exclamation point!', help='System prompt to use for the prompt.')
parser.add_argument('-re', '--rewind', action='store_true', default=False, help='Rewind to the system prompt after each generation. Defaults to false')
args = parser.parse_args()
main(args)
11 changes: 7 additions & 4 deletions src/beam_search_scorer.cpp
@@ -44,13 +44,13 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con
}

BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
: batch_size_{parameters.batch_size},
: batch_size_{parameters.search.batch_size},
num_beams_{parameters.search.num_beams},
max_length_{parameters.search.max_length},
pad_token_id_{parameters.config.model.pad_token_id},
eos_token_id_{parameters.config.model.eos_token_id},
early_stopping_{parameters.search.early_stopping},
not_done_count_{parameters.batch_size} {
not_done_count_{parameters.search.batch_size} {
size_t const batch_beam_size = static_cast<size_t>(batch_size_) * num_beams_;

std::span<HypothesisScore> beams;
@@ -65,15 +65,18 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
next_beam_indices_ptr_ = AllocateArray<int32_t>(batch_beam_size, &next_beam_indices_);

// Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
// TODO(aciddelgado): Initialize in first update function type thing.
// size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2;

hypothesis_buffer_ptr_ = AllocateArray<int32_t>(batch_beam_size * per_beam, &hypothesis_buffer_);

memset(next_beam_scores_.data(), 0, next_beam_scores_.size_bytes());

// Initialize score of first beam of each group with 0 and the rest with -1e9.
// This ensures that the beams in the same group don't produce same tokens every time.
std::span<float> const beam_scores = next_beam_scores_;
for (int i = 0; i < parameters.batch_size; i++) {
for (int i = 0; i < parameters.search.batch_size; i++) {
for (int j = 1; j < parameters.search.num_beams; j++) {
beam_scores[i * parameters.search.num_beams + j] = -1e9;
}
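A note on the per_beam change above: the hypothesis buffer stores one intermediate sequence for each candidate length, so the old per-beam size was the sum of the lengths from sequence_length through max_length. With continuous decoding the prompt length is not known when the scorer is constructed, so the new code reserves the full triangular sum, trading some extra memory for that flexibility. With $s$ = sequence_length and $L$ = max_length:

$$\sum_{\ell=s}^{L} \ell = \frac{L(L+1) - (s-1)s}{2} \;\le\; \sum_{\ell=1}^{L} \ell = \frac{L(L+1)}{2}$$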
10 changes: 6 additions & 4 deletions src/beam_search_scorer_cuda.cpp
@@ -9,13 +9,13 @@ namespace Generators {
BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
: stream_{parameters.cuda_stream} {
state_cpu_ = CudaMallocHostArray<cuda::BeamScorerState>(1);
state_cpu_->batch_size_ = static_cast<size_t>(parameters.batch_size);
state_cpu_->batch_size_ = static_cast<size_t>(parameters.search.batch_size);
state_cpu_->num_beams_ = static_cast<size_t>(parameters.search.num_beams);
state_cpu_->max_length_ = static_cast<size_t>(parameters.search.max_length);
state_cpu_->pad_token_id_ = parameters.config.model.pad_token_id;
state_cpu_->eos_token_id_ = parameters.config.model.eos_token_id;
state_cpu_->early_stopping_ = parameters.search.early_stopping;
state_cpu_->not_done_count_ = parameters.batch_size;
state_cpu_->not_done_count_ = parameters.search.batch_size;
state_cpu_->hypothesis_buffer_used_ = 0;
state_gpu_ = CudaMallocArray<cuda::BeamScorerState>(1);
cudaMemcpyAsync(state_gpu_.get(), state_cpu_.get(), sizeof(cuda::BeamScorerState), ::cudaMemcpyHostToDevice, stream_);
@@ -34,10 +34,12 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
next_beam_indices_cpu_ptr_ = std::make_unique<int32_t[]>(batch_beam_size);
next_beam_indices_cpu_ = cpu_span<int32_t>(next_beam_indices_cpu_ptr_.get(), batch_beam_size);

cuda::LaunchInitScoresKernel(next_beam_scores_.data(), parameters.batch_size, parameters.search.num_beams, stream_);
cuda::LaunchInitScoresKernel(next_beam_scores_.data(), parameters.search.batch_size, parameters.search.num_beams, stream_);

// Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
// TODO(aciddelgado): Initialize in first update function type thing.
// size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1)) / 2;
hypothesis_buffer_ptr_ = CudaMallocArray<int32_t>(batch_beam_size * per_beam, &hypothesis_buffer_);
}

2 changes: 2 additions & 0 deletions src/config.cpp
@@ -537,6 +537,8 @@ struct Search_Element : JSON::Element {
v_.min_length = static_cast<int>(value);
} else if (name == "max_length") {
v_.max_length = static_cast<int>(value);
} else if (name == "batch_size") {
v_.batch_size = static_cast<int>(value);
} else if (name == "num_beams") {
v_.num_beams = static_cast<int>(value);
} else if (name == "num_return_sequences") {
1 change: 1 addition & 0 deletions src/config.h
@@ -144,6 +144,7 @@ struct Config {
bool do_sample{}; // True to do randomized sampling through top_k and top_p, if false, the top logit score is chosen
int min_length{};
int max_length{}; // If omitted or 0 in json file, will be set to model.context_length on load
int batch_size{1};
int num_beams{1}; // 1 means no beam search.
int num_return_sequences{1};
float repetition_penalty{1.0f}; // 1.0 means no penalty.
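The batch_size field added above defaults to 1 and is parsed from the search section of the model config by the config.cpp change, so it should be settable like any other numeric search option. A short sketch under that assumption; the value 3 matches the three-prompt batch used in model-generate.py, and model and sequences are assumed to be set up as in the earlier sketch:

auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("batch_size", 3);   // assumption: must match the number of prompt sequences given to the generator
params->SetSearchOption("max_length", 256);

auto generator = OgaGenerator::Create(*model, *params);
generator->AddInputSequences(*sequences);   // *sequences holds three encoded prompts
while (!generator->IsDone()) {
  generator->GenerateNextToken();
}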