Skip to content

Commit

Permalink
Address PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Sep 16, 2024
1 parent 028a792 commit c4f4ed3
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 22 deletions.
19 changes: 0 additions & 19 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,25 +183,6 @@ std::vector<std::string> Tokenizer::DecodeBatch(std::span<const int32_t> sequenc
return strings;
}

std::vector<int32_t> Tokenizer::GetDecoderPromptIds(size_t batch_size, const std::string& language,
const std::string& task, int32_t no_timestamps) const {
ort_extensions::OrtxObjectPtr<OrtxTokenId2DArray> prompt_ids;
CheckResult(OrtxGetDecoderPromptIds(tokenizer_, batch_size, language.c_str(),
task.c_str(), no_timestamps, ort_extensions::ptr(prompt_ids)));

std::vector<std::vector<int32_t>> tokens_vector;
std::vector<std::span<const int32_t>> span_sequences;
for (size_t i = 0; i < batch_size; i++) {
const extTokenId_t* tokens = nullptr;
size_t token_count = 0;
CheckResult(OrtxTokenId2DArrayGetItem(prompt_ids.get(), i, &tokens, &token_count));
tokens_vector.emplace_back(tokens, tokens + token_count);
span_sequences.emplace_back(tokens_vector.back());
}

return PadInputs(span_sequences, pad_token_id_);
}

int32_t Tokenizer::TokenToTokenId(const char* token) const {
extTokenId_t token_id;
CheckResult(OrtxConvertTokenToId(tokenizer_, token, &token_id));
Expand Down
3 changes: 0 additions & 3 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ struct Tokenizer : std::enable_shared_from_this<Tokenizer>, LeakChecked<Tokenize
std::vector<int32_t> EncodeBatch(std::span<const std::string> strings) const;
std::vector<std::string> DecodeBatch(std::span<const int32_t> sequences, size_t count) const;

std::vector<int32_t> GetDecoderPromptIds(size_t batch_size, const std::string& language,
const std::string& task, int32_t no_timestamps) const;

int32_t TokenToTokenId(const char* token) const;

OrtxPtr<OrtxTokenizer> tokenizer_;
Expand Down

0 comments on commit c4f4ed3

Please sign in to comment.