update
zhihao committed Oct 15, 2024
1 parent c46ddc0 commit 2b5a023
Showing 2 changed files with 40 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/flexflow/serve/serve.py
@@ -350,7 +350,7 @@ def download_hf_tokenizer_if_needed(self):
f"'{self.model_name}' tokenizer needs updating! Downloading tokenizer now..."
)
# Load/download the tokenizer files
target_tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]
target_tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "vocab.json", "merges.txt"]
if os.path.exists(self.model_name):
hf_tokenizer_path = self.model_name
else:
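Note on the serve.py change: vocab.json and merges.txt are the byte-level BPE vocabulary and merge-rule files. The OPT branch of register_tokenizer (visible at the end of the next hunk) reads exactly these two files, so they need to be in the download list alongside tokenizer.json. Below is a minimal sketch of that consumption path, assuming the tokenizers-cpp FromBlobByteLevelBPE entry point and a LoadBytesFromFile helper (sketched after the next hunk); the function name here is made up for illustration.

#include <memory>
#include <string>

#include "tokenizers_cpp.h"  // mlc-ai/tokenizers-cpp

using tokenizers::Tokenizer;

std::string LoadBytesFromFile(std::string const &path);  // whole-file reader

// Sketch: build a byte-level BPE tokenizer (OPT/GPT-2 style) from the two
// files this commit adds to the download list.
std::unique_ptr<Tokenizer> MakeByteLevelBPETokenizer(std::string const &dir) {
  std::string vocab = LoadBytesFromFile(dir + "/vocab.json");
  std::string merges = LoadBytesFromFile(dir + "/merges.txt");
  // Third argument is an optional added-tokens blob; empty for this sketch.
  return Tokenizer::FromBlobByteLevelBPE(vocab, merges, "");
}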
46 changes: 39 additions & 7 deletions src/runtime/request_manager.cc
@@ -186,19 +186,27 @@ void RequestManager::register_tokenizer(ModelType type,
   std::filesystem::path tokenizer_folder(path);
 
   if (model_type == ModelType::LLAMA) {
+    // try with tokenizer.json first
     std::filesystem::path tokenizer_json_path;
     if (std::filesystem::is_directory(tokenizer_folder)) {
-      tokenizer_json_path =
-          std::filesystem::path(tokenizer_folder) / "tokenizer.json";
+      tokenizer_json_path = std::filesystem::path(tokenizer_folder) / "tokenizer.json";
     } else {
       tokenizer_json_path = tokenizer_folder;
     }
-    if (!std::filesystem::exists(tokenizer_json_path)) {
-      std::cerr << "Failed to open file: " << tokenizer_json_path << std::endl;
-      assert(false);
+    if (std::filesystem::exists(tokenizer_json_path)) {
+      // load from tokenizer.json
+      this->tokenizer_ = Tokenizer::FromBlobJSON(LoadBytesFromFile(tokenizer_json_path.string()));
+    } else {
+      // load from tokenizer.model
+      std::filesystem::path tokenizer_model_path =
+          tokenizer_folder / "tokenizer.model";
+      if (!std::filesystem::exists(tokenizer_model_path)) {
+        std::cerr << "Failed to open file: " << tokenizer_model_path
+                  << std::endl;
+        assert(false);
+      }
+      this->tokenizer_ = Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(tokenizer_model_path.string()));
     }
-    this->tokenizer_ = Tokenizer::FromBlobJSON(
-        LoadBytesFromFile(tokenizer_json_path.string()));
   } else if (model_type == ModelType::OPT) {
     std::filesystem::path vocab_file = tokenizer_folder / "vocab.json";
     std::filesystem::path merges_file = tokenizer_folder / "merges.txt";
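Both branches above pull the raw blob through LoadBytesFromFile, which is defined elsewhere in FlexFlow and not part of this diff. A minimal sketch of such a whole-file reader, modeled on the example helper shipped with tokenizers-cpp:

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>

// Sketch of a whole-file reader like the one this code calls; the real
// definition lives outside this diff.
std::string LoadBytesFromFile(std::string const &path) {
  std::ifstream fs(path, std::ios::in | std::ios::binary);
  if (fs.fail()) {
    std::cerr << "Cannot open " << path << std::endl;
    exit(1);
  }
  fs.seekg(0, std::ios::end);
  size_t size = static_cast<size_t>(fs.tellg());
  fs.seekg(0, std::ios::beg);
  std::string data(size, '\0');
  fs.read(data.data(), size);
  return data;
}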
@@ -648,6 +656,12 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
     bool request_completed = check_inf_req_completion(old_bc, i);
     if (request_completed) {
       std::string output = this->tokenizer_->Decode(request.tokens);
+      // Unlike Huggingface, the sentencepiece C++ library automatically
+      // removes the BOS token
+      if (model_type == ModelType::LLAMA &&
+          request.tokens.at(0) == bos_token_id) {
+        output = "<s> " + output;
+      }
       {
         // update generation result
         GenerationResult &gr = request_generation_results[request.guid];
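The six added lines compensate for a decoder difference called out in the comment: SentencePiece's C++ Decode drops the BOS token, while Huggingface's decoder keeps it, so the runtime re-prepends "<s> " for LLAMA requests. The identical block recurs at the three decode sites in the hunks below; a sketch of how it could be hoisted into one helper (the helper is hypothetical, not FlexFlow API; token ids are assumed to be plain ints):

#include <string>
#include <vector>

// Hypothetical helper, not part of this commit: re-add the BOS marker that
// SentencePiece's C++ decoder strips, matching Huggingface's decoded text.
static std::string MaybePrependBOS(std::string output,
                                   ModelType model_type,
                                   std::vector<int> const &tokens,
                                   int bos_token_id) {
  if (model_type == ModelType::LLAMA && !tokens.empty() &&
      tokens.front() == bos_token_id) {
    return "<s> " + output;
  }
  return output;
}

The extra !tokens.empty() guard also avoids the std::out_of_range that the inlined request.tokens.at(0) would throw on an empty token vector.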
@@ -1103,6 +1117,12 @@ BeamSearchBatchConfig
                       request.guid,
                       request.tokens.size());
       std::string output = this->tokenizer_->Decode(request.tokens);
+      // Unlike Huggingface, the sentencepiece C++ library automatically
+      // removes the BOS token
+      if (model_type == ModelType::LLAMA &&
+          request.tokens.at(0) == bos_token_id) {
+        output = "<s> " + output;
+      }
       {
         // update generation result
         GenerationResult &gr = request_generation_results[request.guid];
@@ -1240,6 +1260,12 @@ BeamSearchBatchConfig
     }
 
     std::string output = this->tokenizer_->Decode(request.tokens);
+    // Unlike Huggingface, the sentencepiece C++ library automatically
+    // removes the BOS token
+    if (model_type == ModelType::LLAMA &&
+        request.tokens.at(0) == bos_token_id) {
+      output = "<s> " + output;
+    }
     log_req_mgr.print("Output: %s", output.c_str());
   }
 
@@ -1282,6 +1308,12 @@ BeamSearchBatchConfig

     // Token Info
     std::string output = this->tokenizer_->Decode(request.tokens);
+    // Unlike Huggingface, the sentencepiece C++ library automatically removes
+    // the BOS token
+    if (model_type == ModelType::LLAMA &&
+        request.tokens.at(0) == bos_token_id) {
+      output = "<s> " + output;
+    }
     log_req_mgr.print("Output: %s", output.c_str());
   } else {
     assert(false);
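For reference, a self-contained round trip that exercises the behavior the repeated comment describes. Everything specific here is an assumption made for illustration: the tokenizer.model path, the BOS id of 1, and the prompt; the tokenizers-cpp calls are the same ones used above, plus Encode.

#include <iostream>
#include <string>
#include <vector>

#include "tokenizers_cpp.h"  // mlc-ai/tokenizers-cpp

using tokenizers::Tokenizer;

std::string LoadBytesFromFile(std::string const &path);  // see sketch above

int main() {
  // Assumed local path to a LLAMA-style SentencePiece model.
  auto tok = Tokenizer::FromBlobSentencePiece(
      LoadBytesFromFile("tokenizer.model"));
  int const bos_token_id = 1;  // assumption: the usual LLAMA BOS id
  std::vector<int32_t> ids = tok->Encode("Hello world");
  ids.insert(ids.begin(), bos_token_id);  // simulate a prompt with BOS
  std::string output = tok->Decode(ids);
  // Per the commit's comment, the decoded text comes back without "<s>",
  // so re-prepend it the same way request_manager.cc does.
  if (!ids.empty() && ids.front() == bos_token_id) {
    output = "<s> " + output;
  }
  std::cout << output << std::endl;
  return 0;
}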
