Fix edge cases with specific prompt lengths.
zwang86 committed Jul 26, 2023
1 parent 497a945 commit f5bf9e6
Showing 3 changed files with 62 additions and 13 deletions.
2 changes: 2 additions & 0 deletions include/flexflow/batch_config.h
@@ -54,6 +54,8 @@ class BatchConfig {
static BatchConfig const *from_future(BatchConfigFuture const &future);
static int const MAX_NUM_REQUESTS = 1;
static int const MAX_NUM_TOKENS = 64;
static int const MAX_PROMPT_LENGTH =
63; // should be MAX_NUM_TOKENS - 1 for SpecInfer
static int const MAX_SEQ_LENGTH = 256;

// These are set by update
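The comment ties the new constant to MAX_NUM_TOKENS. As a minimal sketch (not part of this commit), the documented relationship could be pinned down with a compile-time check placed anywhere the header is included:

// Illustrative only: enforce the "MAX_NUM_TOKENS - 1" relationship documented above.
static_assert(BatchConfig::MAX_PROMPT_LENGTH == BatchConfig::MAX_NUM_TOKENS - 1,
              "SpecInfer reserves one token slot beyond the prompt");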
49 changes: 42 additions & 7 deletions src/runtime/request_manager.cc
@@ -127,8 +127,22 @@ RequestManager::RequestGuid
Request request;
request.guid = next_available_guid++;
request.max_sequence_length = max_sequence_length;
request.initial_len = prompt.size();
request.tokens = prompt;

if (prompt.size() > BatchConfig::MAX_PROMPT_LENGTH) {
std::cout << "Warning: too many tokens in prompt, only load up to "
<< BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
<< prompt.size() << ".\n";
// Truncate the prompt to MAX_PROMPT_LENGTH tokens
request.tokens.insert(request.tokens.end(),
prompt.begin(),
prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH);
request.initial_len = BatchConfig::MAX_PROMPT_LENGTH;
printf("tokens size: %zu\n", request.tokens.size());
// assert(false);
} else {
request.initial_len = prompt.size();
request.tokens = prompt;
}

if (get_num_ssms() == 0) {
std::cout << "No small speculative model registered yet, using incremental "
@@ -175,12 +189,22 @@ RequestManager::RequestGuid
request.tokens.push_back(this->model_bos_map.at(this->model_type));
std::vector<int32_t> tokens = this->tokenizer_->Encode(prompt);

if (tokens.size() > BatchConfig::MAX_PROMPT_LENGTH) {
std::cout << "Warning: too many tokens in prompt, only load up to "
<< BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
<< tokens.size() << ".\n";
// Truncate the prompt to MAX_PROMPT_LENGTH tokens
tokens.resize(BatchConfig::MAX_PROMPT_LENGTH);
printf("tokens size: %zu\n", tokens.size());
// assert(false);
}

for (int i = 0; i < tokens.size(); i++) {
std::cout << tokens.at(i) << "\n";
std::cout << "[" << i << "]" << tokens.at(i) << "\n";
}

// assert(false);
request.tokens.insert(request.tokens.end(), tokens.begin(), tokens.end());
request.tokens = tokens;
request.initial_len = request.tokens.size();

if (get_num_ssms() == 0) {
@@ -809,6 +833,7 @@ BeamSearchBatchConfig
(int)new_request.tokens.size());
new_bc.requestsInfo[i].max_sequence_length =
new_request.max_sequence_length;

// add profile_info for the new request
ProfileInfo profile_info;
profile_info.decoding_steps = 0;
@@ -818,6 +843,10 @@ BeamSearchBatchConfig
new_bc.beamRequestsInfo[i].beam_size =
BeamSearchBatchConfig::MAX_BEAM_WIDTH;
new_bc.beamRequestsInfo[i].current_depth = 1;
new_bc.beamRequestsInfo[i].max_depth =
std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH,
BatchConfig::MAX_NUM_TOKENS -
new_bc.requestsInfo[i].num_tokens_in_batch - 1);
for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) {
new_bc.beamRequestsInfo[i].parent_id[j] = 0;
new_bc.beamRequestsInfo[i].probs[j] = 1;
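The new max_depth bound is what handles a full-length prompt: with MAX_NUM_TOKENS = 64 and a 63-token prompt, there is no room left in the batch for speculative tokens, so max_depth is clamped to 0 instead of MAX_BEAM_DEPTH. A small standalone check of that arithmetic (MAX_BEAM_DEPTH's actual value is not shown in this diff, so the value below is illustrative):

#include <algorithm>
#include <cassert>

int main() {
  int const MAX_NUM_TOKENS = 64;  // from batch_config.h above
  int const MAX_BEAM_DEPTH = 8;   // illustrative; the real constant is defined elsewhere
  int num_tokens_in_batch = 63;   // a prompt at MAX_PROMPT_LENGTH
  int max_depth =
      std::min(MAX_BEAM_DEPTH, MAX_NUM_TOKENS - num_tokens_in_batch - 1);
  assert(max_depth == 0);  // no speculative depth available for this batch
  return 0;
}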
@@ -947,7 +976,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
new_bc.num_tokens++;
new_bc.requestsInfo[i].num_tokens_in_batch++;
}
if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS) {

std::cout << "new_bc.num_tokens: " << new_bc.num_tokens << std::endl;
if (new_bc.num_tokens >= BatchConfig::MAX_NUM_TOKENS) {
assert(false &&
"Exceeding the space available in the TreeVerify batch");
break;
@@ -1065,6 +1096,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
break;
}
}

std::cout << "new_bc.num_tokens: " << new_bc.num_tokens << std::endl;
}

if (verbose) {
@@ -1346,9 +1379,11 @@ std::vector<std::pair<BatchConfig::TokenId, int>>
log_req_mgr.print("(%d, %d)", pair.first, pair.second);
}

assert(inputSerializedTree.size() == outputSerializedTree.size());
// It's safe to have inputSerializedTree.size() > outputSerializedTree.size()
// In this case the inputSerializedTree ends with padding 0s
assert(inputSerializedTree.size() >= outputSerializedTree.size());

for (int i = 0; i < inputSerializedTree.size(); i++) {
for (int i = 0; i < outputSerializedTree.size(); i++) {
auto input = inputSerializedTree.at(i);
auto output = outputSerializedTree.at(i);

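Both register_new_request overloads above now clamp the prompt with the same pattern; the sketch below restates it as a standalone helper (the truncate_prompt name is illustrative and does not exist in the codebase):

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: keep at most max_len tokens, warning when the prompt is
// longer, mirroring the clamping added in this commit.
std::vector<int32_t> truncate_prompt(std::vector<int32_t> const &prompt,
                                     size_t max_len) {
  if (prompt.size() > max_len) {
    std::printf("Warning: prompt has %zu tokens, keeping the first %zu\n",
                prompt.size(), max_len);
    return std::vector<int32_t>(prompt.begin(), prompt.begin() + max_len);
  }
  return prompt;
}

A caller would pass BatchConfig::MAX_PROMPT_LENGTH as max_len, matching both call sites in the diff.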
24 changes: 18 additions & 6 deletions src/runtime/request_manager.cu
@@ -31,7 +31,16 @@ void RequestManager::load_tokens_task(
// BatchConfig const batch_config = *((BatchConfig *)task->args);
BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]);
BatchConfig::TokenId dram_copy[BatchConfig::MAX_NUM_TOKENS];
assert(batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS);

// Extremely long prompts are not supported; only load up to MAX_NUM_TOKENS
// tokens of the prompt
if (batch_config->num_tokens > BatchConfig::MAX_NUM_TOKENS) {
printf("Warning: too many tokens in prompt, only load up to %d tokens\n",
BatchConfig::MAX_NUM_TOKENS);
printf("Got: %d tokens\n", batch_config->num_tokens);
}
// assert(batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS);

for (int i = 0; i < batch_config->num_tokens; i++) {
dram_copy[i] = batch_config->tokensInfo[i].token_id;
}
@@ -42,11 +51,14 @@ void RequestManager::load_tokens_task(
assert(batch_config->num_tokens <= domain.get_volume());
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDA(cudaMemcpyAsync(fb_ptr,
dram_copy,
sizeof(TokenId) * batch_config->num_tokens,
cudaMemcpyHostToDevice,
stream));
checkCUDA(
cudaMemcpyAsync(fb_ptr,
dram_copy,
batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS
? sizeof(TokenId) * batch_config->num_tokens
: sizeof(TokenId) * BatchConfig::MAX_NUM_TOKENS,
cudaMemcpyHostToDevice,
stream));
}

void RequestManager::load_positions_task(
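The ternary added to the cudaMemcpyAsync call above caps the copied byte count at MAX_NUM_TOKENS tokens. A self-contained sketch of the same clamping, factored into a helper (the tokens_copy_bytes name is illustrative):

#include <algorithm>
#include <cstddef>

// Illustrative only: bytes to copy for num_tokens tokens of size token_size,
// never exceeding max_num_tokens tokens.
size_t tokens_copy_bytes(int num_tokens, int max_num_tokens, size_t token_size) {
  return token_size * static_cast<size_t>(std::min(num_tokens, max_num_tokens));
}

In load_tokens_task this would be called as tokens_copy_bytes(batch_config->num_tokens, BatchConfig::MAX_NUM_TOKENS, sizeof(TokenId)), reproducing the value of the conditional expression in the diff.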
