Commit

Address pull-request review comments
baijumeswani committed Sep 11, 2024
1 parent e06c46f commit a3ddc38
Showing 5 changed files with 45 additions and 20 deletions.
6 changes: 4 additions & 2 deletions examples/c/src/whisper.cpp
@@ -63,7 +63,7 @@ void CXX_API(const char* model_path, int32_t num_beams) {
   const size_t batch_size = audio_paths.size();
   for (size_t i = 0; i < batch_size; ++i) {
     for (const auto& token : prompt_tokens) {
-      tokenizer->ToTokenId(token, *input_ids, i);
+      input_ids->Append(tokenizer->ToTokenId(token), i);
     }
   }

@@ -155,7 +155,9 @@ void C_API(const char* model_path, int32_t num_beams) {
   const size_t batch_size = audio_paths.size();
   for (size_t i = 0; i < batch_size; ++i) {
     for (const auto& token : prompt_tokens) {
-      CheckResult(OgaTokenizerToTokenId(tokenizer, token, input_ids, i));
+      int32_t token_id;
+      CheckResult(OgaTokenizerToTokenId(tokenizer, token, &token_id));
+      CheckResult(OgaAppendTokenToSequence(token_id, input_ids, i));
     }
   }

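For readers who only see the fragments above, here is a self-contained sketch of how the refactored C-API seeding loop fits together. The `SeedDecoderPrompts` helper, the local `CheckResult`, `prompt_tokens`, and `batch_size` are stand-ins for the surrounding example code, not part of this commit; `OgaCreateTokenizer`/`OgaCreateSequences` are assumed from the existing C API.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

#include "ort_genai_c.h"

// Mirrors the CheckResult helper used in the whisper.cpp example.
static void CheckResult(OgaResult* result) {
  if (result) {
    std::string error_message(OgaResultGetError(result));
    OgaDestroyResult(result);
    throw std::runtime_error(error_message);
  }
}

void SeedDecoderPrompts(const OgaModel* model, size_t batch_size) {
  OgaTokenizer* tokenizer = nullptr;
  CheckResult(OgaCreateTokenizer(model, &tokenizer));

  OgaSequences* input_ids = nullptr;
  CheckResult(OgaCreateSequences(&input_ids));

  const std::vector<const char*> prompt_tokens = {
      "<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"};

  // One decoder prompt per batch entry: convert each special token to its id,
  // then append it to the i-th sequence (created on the first append at index i).
  for (size_t i = 0; i < batch_size; ++i) {
    for (const char* token : prompt_tokens) {
      int32_t token_id;
      CheckResult(OgaTokenizerToTokenId(tokenizer, token, &token_id));
      CheckResult(OgaAppendTokenToSequence(token_id, input_ids, i));
    }
  }

  // ... hand input_ids to the generator params here ...

  OgaDestroySequences(input_ids);
  OgaDestroyTokenizer(tokenizer);
}
```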
1 change: 0 additions & 1 deletion examples/python/whisper.py
@@ -41,7 +41,6 @@ def run(args: argparse.Namespace):
     print("Processing audio...")
     mel = processor(audios=audios)
     decoder_prompt_tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
-    print([tokenizer.to_token_id(token) for token in decoder_prompt_tokens])
 
     params = og.GeneratorParams(model)
     params.set_search_options(
11 changes: 9 additions & 2 deletions src/ort_genai.h
@@ -106,6 +106,11 @@ struct OgaSequences : OgaAbstract {
   void Append(const int32_t* tokens, size_t token_cnt) {
     OgaCheckResult(OgaAppendTokenSequence(tokens, token_cnt, this));
   }
+
+  void Append(int32_t token, size_t sequence_index) {
+    OgaCheckResult(OgaAppendTokenToSequence(token, this, sequence_index));
+  }
+
 #if __cplusplus >= 202002L
   std::span<const int32_t> Get(size_t index) const {
     return {SequenceData(index), SequenceCount(index)};
@@ -132,8 +137,10 @@ struct OgaTokenizer : OgaAbstract {
     OgaCheckResult(OgaTokenizerEncode(this, str, &sequences));
   }
 
-  void ToTokenId(const char* str, OgaSequences& sequences, size_t sequence_idx) const {
-    OgaCheckResult(OgaTokenizerToTokenId(this, str, &sequences, sequence_idx));
+  int32_t ToTokenId(const char* str) const {
+    int32_t token_id;
+    OgaCheckResult(OgaTokenizerToTokenId(this, str, &token_id));
+    return token_id;
   }
 
   OgaString Decode(const int32_t* tokens_data, size_t tokens_length) const {
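A minimal usage sketch of the two updated C++ wrappers. The `OgaSequences::Create()` factory and the surrounding function are assumptions drawn from the existing ort_genai.h surface, not part of this commit.

```cpp
#include <cstdint>

#include "ort_genai.h"

void AppendStartToken(const OgaTokenizer& tokenizer) {
  auto sequences = OgaSequences::Create();

  // ToTokenId now returns the id directly instead of writing into a sequence...
  const int32_t start_id = tokenizer.ToTokenId("<|startoftranscript|>");

  // ...and the new Append overload places it into sequence 0, which is created
  // here because the index equals the current sequence count (0).
  sequences->Append(start_id, 0);
}
```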
28 changes: 18 additions & 10 deletions src/ort_genai_c.cpp
@@ -70,6 +70,22 @@ OgaResult* OGA_API_CALL OgaAppendTokenSequence(const int32_t* token_ptr, size_t
   OGA_CATCH
 }
 
+OgaResult* OGA_API_CALL OgaAppendTokenToSequence(int32_t token, OgaSequences* sequences, size_t sequence_index) {
+  OGA_TRY
+  Generators::TokenSequences* toks = reinterpret_cast<Generators::TokenSequences*>(sequences);
+  if (sequence_index > toks->size()) {
+    throw std::runtime_error("sequence index out of bounds");
+  }
+  if (sequence_index == toks->size()) {
+    toks->emplace_back();
+  }
+
+  toks->at(sequence_index).push_back(token);
+
+  return nullptr;
+  OGA_CATCH
+}
+
 size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* p) {
   return reinterpret_cast<const Generators::TokenSequences*>(p)->size();
 }
@@ -321,18 +337,10 @@ OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* p, const char* st
   OGA_CATCH
 }
 
-OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* p, const char* str, OgaSequences* sequences, size_t sequence_idx) {
+OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* p, const char* str, int32_t* token_id) {
   OGA_TRY
   auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p);
-  auto& token_sequences = *reinterpret_cast<Generators::TokenSequences*>(sequences);
-  if (sequence_idx > token_sequences.size())
-    throw std::runtime_error("sequence_idx is out of bounds");
-
-  if (sequence_idx == token_sequences.size())
-    token_sequences.push_back({});
-
-  token_sequences[sequence_idx].push_back(tokenizer.TokenToTokenId(str));
-
+  *token_id = tokenizer.TokenToTokenId(str);
   return nullptr;
   OGA_CATCH
 }
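The implementation above creates a new sequence only when `sequence_index` equals the current sequence count and rejects any larger index. A small sketch of that boundary behaviour through the C API; `OgaCreateSequences`, `OgaSequencesCount`, `OgaResultGetError`, and `OgaDestroySequences` are assumed from the existing API surface and are not part of this diff.

```cpp
#include <cstdio>

#include "ort_genai_c.h"

void AppendBehaviourDemo() {
  OgaSequences* sequences = nullptr;
  if (OgaResult* create_err = OgaCreateSequences(&sequences)) {
    OgaDestroyResult(create_err);
    return;
  }

  // Index 0 on an empty container equals the current count (0), so a new
  // sequence is created and the token appended; the count becomes 1.
  if (OgaResult* err = OgaAppendTokenToSequence(42 /* arbitrary token id */, sequences, 0)) {
    OgaDestroyResult(err);
  }
  std::printf("sequence count: %zu\n", OgaSequencesCount(sequences));  // prints 1

  // Index 2 skips past the current count (1), so the call returns an error result.
  if (OgaResult* err = OgaAppendTokenToSequence(43, sequences, 2)) {
    std::printf("error: %s\n", OgaResultGetError(err));
    OgaDestroyResult(err);
  }

  OgaDestroySequences(sequences);
}
```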
19 changes: 14 additions & 5 deletions src/ort_genai_c.h
@@ -110,6 +110,17 @@ OGA_EXPORT size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* sequences);
  */
 OGA_EXPORT OgaResult* OGA_API_CALL OgaAppendTokenSequence(const int32_t* token_ptr, size_t token_cnt, OgaSequences* sequence);
 
+/*
+ * \brief Appends the given token to the sequence at the given index.
+          If the sequence at the given index does not exist, a new sequence is
+          created at the given index when sequence_index equals the current sequence count.
+ * \param[in] token token to append to the sequence
+ * \param[in] sequence OgaSequences object to append the token to
+ * \param[in] sequence_index index of the sequence to append the token to
+ * \return OgaResult containing the error message if the token could not be added, else nullptr.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaAppendTokenToSequence(int32_t token, OgaSequences* sequence, size_t sequence_index);
+
 /*
  * \brief Returns the number of tokens in the sequence at the given index
  * \param[in] sequences
@@ -281,15 +292,13 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyMultiModalProcessor(OgaMultiModalProcesso
 OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences);
 
 /*
- * \brief Converts the given string to a single token id and adds it to the end of the sequence at the given index.
-          If the sequence does not exist, a new sequence is created at the given index if sequence_idx is equal to the current sequences count.
+ * \brief Converts the given string to a single token id.
  * \param[in] tokenizer The tokenizer to use to convert the string to a token id.
  * \param[in] str The string to convert to a token id.
- * \param[in] sequences The OgaSequences to add the token id to.
- * \param[in] sequence_idx The index of the sequence to add the token id to.
+ * \param[out] token_id The converted token id.
  * \return OgaResult containing the error message if the conversion of the string to a token id failed.
  */
-OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* tokenizer, const char* str, OgaSequences* sequences, size_t sequence_idx);
+OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* tokenizer, const char* str, int32_t* token_id);
 
 OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImages(const OgaMultiModalProcessor*, const char* prompt, const OgaImages* images, OgaNamedTensors** input_tensors);
