diff --git a/examples/c/src/whisper.cpp b/examples/c/src/whisper.cpp index 2b781b854..63c06b89b 100644 --- a/examples/c/src/whisper.cpp +++ b/examples/c/src/whisper.cpp @@ -63,7 +63,7 @@ void CXX_API(const char* model_path, int32_t num_beams) { const size_t batch_size = audio_paths.size(); for (size_t i = 0; i < batch_size; ++i) { for (const auto& token : prompt_tokens) { - tokenizer->ToTokenId(token, *input_ids, i); + input_ids->Append(tokenizer->ToTokenId(token), i); } } @@ -155,7 +155,9 @@ void C_API(const char* model_path, int32_t num_beams) { const size_t batch_size = audio_paths.size(); for (size_t i = 0; i < batch_size; ++i) { for (const auto& token : prompt_tokens) { - CheckResult(OgaTokenizerToTokenId(tokenizer, token, input_ids, i)); + int32_t token_id; + CheckResult(OgaTokenizerToTokenId(tokenizer, token, &token_id)); + CheckResult(OgaAppendTokenToSequence(token_id, input_ids, i)); } } diff --git a/examples/python/whisper.py b/examples/python/whisper.py index eccd84af9..581f6bd28 100644 --- a/examples/python/whisper.py +++ b/examples/python/whisper.py @@ -41,7 +41,6 @@ def run(args: argparse.Namespace): print("Processing audio...") mel = processor(audios=audios) decoder_prompt_tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"] - print([tokenizer.to_token_id(token) for token in decoder_prompt_tokens]) params = og.GeneratorParams(model) params.set_search_options( diff --git a/src/ort_genai.h b/src/ort_genai.h index 3cbd4324a..36b56d9bd 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -106,6 +106,11 @@ struct OgaSequences : OgaAbstract { void Append(const int32_t* tokens, size_t token_cnt) { OgaCheckResult(OgaAppendTokenSequence(tokens, token_cnt, this)); } + + void Append(int32_t token, size_t sequence_index) { OgaCheckResult(OgaAppendTokenToSequence(token, this, sequence_index)); } + + #if __cplusplus >= 202002L std::span<const int32_t> Get(size_t index) const { return {SequenceData(index), SequenceCount(index)}; } #endif @@ -132,8 +137,10 
@@ struct OgaTokenizer : OgaAbstract { OgaCheckResult(OgaTokenizerEncode(this, str, &sequences)); } - void ToTokenId(const char* str, OgaSequences& sequences, size_t sequence_idx) const { - OgaCheckResult(OgaTokenizerToTokenId(this, str, &sequences, sequence_idx)); + int32_t ToTokenId(const char* str) const { + int32_t token_id; + OgaCheckResult(OgaTokenizerToTokenId(this, str, &token_id)); + return token_id; } OgaString Decode(const int32_t* tokens_data, size_t tokens_length) const { diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 13b336559..fdb292258 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -70,6 +70,22 @@ OgaResult* OGA_API_CALL OgaAppendTokenSequence(const int32_t* token_ptr, size_t OGA_CATCH } +OgaResult* OGA_API_CALL OgaAppendTokenToSequence(int32_t token, OgaSequences* sequences, size_t sequence_index) { + OGA_TRY + Generators::TokenSequences* toks = reinterpret_cast<Generators::TokenSequences*>(sequences); + if (sequence_index > toks->size()) { + throw std::runtime_error("sequence index out of bounds"); + } + if (sequence_index == toks->size()) { + toks->emplace_back(); + } + + toks->at(sequence_index).push_back(token); + + return nullptr; + OGA_CATCH +} + size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* p) { return reinterpret_cast<const Generators::TokenSequences*>(p)->size(); } @@ -321,18 +337,10 @@ OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* p, const char* st OGA_CATCH } -OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* p, const char* str, OgaSequences* sequences, size_t sequence_idx) { +OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* p, const char* str, int32_t* token_id) { OGA_TRY auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p); - auto& token_sequences = *reinterpret_cast<Generators::TokenSequences*>(sequences); - if (sequence_idx > token_sequences.size()) throw std::runtime_error("sequence_idx is out of bounds"); - - if (sequence_idx == token_sequences.size()) - token_sequences.push_back({}); - - 
token_sequences[sequence_idx].push_back(tokenizer.TokenToTokenId(str)); - + *token_id = tokenizer.TokenToTokenId(str); return nullptr; OGA_CATCH } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 870b96c49..c1d03f8e1 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -110,6 +110,17 @@ OGA_EXPORT size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* sequences); */ OGA_EXPORT OgaResult* OGA_API_CALL OgaAppendTokenSequence(const int32_t* token_ptr, size_t token_cnt, OgaSequences* sequence); +/* + * \brief Appends the given token to the sequence at the given index. If the sequence at the given index does not exist, a new sequence is created at the given index if sequence_index is equal to the current sequences count. + * \param[in] token token to append to the sequence + * \param[in] sequences OgaSequences object to append the token to + * \param[in] sequence_index index of the sequence to append the token to + * \return OgaResult containing the error message when tokens could not be added, else nullptr. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaAppendTokenToSequence(int32_t token, OgaSequences* sequence, size_t sequence_index); + /* * \brief Returns the number of tokens in the sequence at the given index * \param[in] sequences @@ -281,15 +292,13 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyMultiModalProcessor(OgaMultiModalProcesso OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences); /* - * \brief Converts the given string to a single token id and adds it to the end of the sequence at the given index. If the sequence does not exist, a new sequence is created at the given index if sequence_idx is equal to the current sequences count. + * \brief Converts the given string to a single token id. * \param[in] tokenizer The tokenizer to use to convert the string to a token id. * \param[in] str The string to convert to a token id. 
- * \param[in] sequences The OgaSequences to add the token id to. - * \param[in] sequence_idx The index of the sequence to add the token id to. + * \param[out] token_id The converted token id. * \return OgaResult containing the error message if the conversion of the string to a token id failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* tokenizer, const char* str, OgaSequences* sequences, size_t sequence_idx); +OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* tokenizer, const char* str, int32_t* token_id); OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImages(const OgaMultiModalProcessor*, const char* prompt, const OgaImages* images, OgaNamedTensors** input_tensors);