Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add performance metrics into phi3 C example. #928

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions examples/c/src/phi3.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cassert>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <string>
#include "ort_genai.h"

using Clock = std::chrono::high_resolution_clock;
using TimePoint = std::chrono::time_point<Clock>;
using Duration = std::chrono::duration<double>;

// `Timing` is a utility class for measuring performance metrics.
//
// Usage: call RecordStartTimestamp() before encoding the prompt,
// RecordFirstTokenTimestamp() right after the first token is generated,
// RecordEndTimestamp() when generation finishes, then Log() to print the
// time-to-first-token and throughput numbers to stdout.
class Timing {
 public:
  Timing(const Timing&) = delete;
  Timing& operator=(const Timing&) = delete;

  Timing() = default;

  ~Timing() = default;

  // Marks the start of prompt processing. Must be called exactly once
  // (asserted via the default-constructed zero timestamp).
  void RecordStartTimestamp() {
    assert(start_timestamp_.time_since_epoch().count() == 0);
    start_timestamp_ = Clock::now();
  }

  // Marks the arrival of the first generated token. Must be called exactly once.
  void RecordFirstTokenTimestamp() {
    assert(first_token_timestamp_.time_since_epoch().count() == 0);
    first_token_timestamp_ = Clock::now();
  }

  // Marks the end of generation. Must be called exactly once.
  void RecordEndTimestamp() {
    assert(end_timestamp_.time_since_epoch().count() == 0);
    end_timestamp_ = Clock::now();
  }

  // Prints the metrics for one generation run. All three timestamps must have
  // been recorded first. Restores std::cout's formatting state on return.
  void Log(const int prompt_tokens_length, const int new_tokens_length) const {
    assert(start_timestamp_.time_since_epoch().count() != 0);
    assert(first_token_timestamp_.time_since_epoch().count() != 0);
    assert(end_timestamp_.time_since_epoch().count() != 0);

    // Time spent processing the prompt (up to the first token) vs. time spent
    // generating the remaining tokens.
    const Duration prompt_time = first_token_timestamp_ - start_timestamp_;
    const Duration run_time = end_timestamp_ - first_token_timestamp_;

    // Save the full stream formatting state. The original code restored only
    // the precision, so std::fixed/std::showpoint leaked onto std::cout and
    // affected every subsequent print in the program.
    const auto default_flags{std::cout.flags()};
    const auto default_precision{std::cout.precision()};
    std::cout << std::endl;
    std::cout << "-------------" << std::endl;
    std::cout << std::fixed << std::showpoint << std::setprecision(2)
              << "Prompt length: " << prompt_tokens_length << ", New tokens: " << new_tokens_length
              << ", Time to first: " << prompt_time.count() << "s"
              << ", Prompt tokens per second: " << prompt_tokens_length / prompt_time.count() << " tps"
              << ", New tokens per second: " << new_tokens_length / run_time.count() << " tps"
              << std::endl;
    std::cout << "-------------" << std::endl;
    // Restore both flags and precision so callers see an unmodified stream.
    std::cout.flags(default_flags);
    std::cout.precision(default_precision);
  }

 private:
  TimePoint start_timestamp_;        // set by RecordStartTimestamp()
  TimePoint first_token_timestamp_;  // set by RecordFirstTokenTimestamp()
  TimePoint end_timestamp_;          // set by RecordEndTimestamp()
};

// C++ API Example

void CXX_API(const char* model_path) {
Expand All @@ -16,11 +74,19 @@ void CXX_API(const char* model_path) {

while (true) {
std::string text;
std::cout << "Prompt: " << std::endl;
std::cout << "Prompt: (Use quit() to exit)" << std::endl;
std::getline(std::cin, text);

if (text == "quit()") {
break; // Exit the loop
}

const std::string prompt = "<|user|>\n" + text + "<|end|>\n<|assistant|>";

bool is_first_token = true;
Timing timing;
timing.RecordStartTimestamp();

auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt.c_str(), *sequences);

Expand All @@ -35,6 +101,11 @@ void CXX_API(const char* model_path) {
generator->ComputeLogits();
generator->GenerateNextToken();

if (is_first_token) {
timing.RecordFirstTokenTimestamp();
is_first_token = false;
}

// Show usage of GetOutput
std::unique_ptr<OgaTensor> output_logits = generator->GetOutput("logits");

Expand All @@ -53,6 +124,11 @@ void CXX_API(const char* model_path) {
std::cout << tokenizer_stream->Decode(new_token) << std::flush;
}

timing.RecordEndTimestamp();
const int prompt_tokens_length = sequences->SequenceCount(0);
const int new_tokens_length = generator->GetSequenceCount(0) - prompt_tokens_length;
timing.Log(prompt_tokens_length, new_tokens_length);

for (int i = 0; i < 3; ++i)
std::cout << std::endl;
}
Expand Down Expand Up @@ -82,11 +158,19 @@ void C_API(const char* model_path) {

while (true) {
std::string text;
std::cout << "Prompt: " << std::endl;
std::cout << "Prompt: (Use quit() to exit)" << std::endl;
std::getline(std::cin, text);

if (text == "quit()") {
break; // Exit the loop
}

const std::string prompt = "<|user|>\n" + text + "<|end|>\n<|assistant|>";

bool is_first_token = true;
Timing timing;
timing.RecordStartTimestamp();

OgaSequences* sequences;
CheckResult(OgaCreateSequences(&sequences));
CheckResult(OgaTokenizerEncode(tokenizer, prompt.c_str(), sequences));
Expand All @@ -104,13 +188,23 @@ void C_API(const char* model_path) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken(generator));

if (is_first_token) {
timing.RecordFirstTokenTimestamp();
is_first_token = false;
}

const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0);
int32_t new_token = OgaGenerator_GetSequenceData(generator, 0)[num_tokens - 1];
const char* new_token_string;
CheckResult(OgaTokenizerStreamDecode(tokenizer_stream, new_token, &new_token_string));
std::cout << new_token_string << std::flush;
}

timing.RecordEndTimestamp();
const int prompt_tokens_length = OgaSequencesGetSequenceCount(sequences, 0);
const int new_tokens_length = OgaGenerator_GetSequenceCount(generator, 0) - prompt_tokens_length;
timing.Log(prompt_tokens_length, new_tokens_length);

for (int i = 0; i < 3; ++i)
std::cout << std::endl;

Expand Down