
Commit 5e52aa6

Discard long prompt request.
zwang86 committed Jul 27, 2023
1 parent f5bf9e6 commit 5e52aa6
Showing 2 changed files with 26 additions and 13 deletions.
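
In short: register_new_request now returns a guid of 0 when a prompt exceeds BatchConfig::MAX_PROMPT_LENGTH instead of truncating it, and generate_incr_decoding / generate_spec_infer return an empty GenerationResult for such a request (see the hunks below). The following is a minimal standalone sketch of that guard pattern; the types, the length check, and main() are simplified stand-ins for illustration, not FlexFlow's actual classes or control flow.

// Minimal standalone sketch of the guard pattern introduced by this commit.
// The types below are simplified stand-ins, not FlexFlow's real classes.
#include <cstdio>
#include <vector>

struct BatchConfig {
  static constexpr int MAX_PROMPT_LENGTH = 1024; // illustrative value
};

struct GenerationResult {
  std::vector<int> output_tokens; // empty when the request was discarded
};

using RequestGuid = unsigned long long;

// Returns 0 (an invalid guid) when the tokenized prompt is too long.
RequestGuid register_new_request(std::vector<int> const &prompt_tokens) {
  if ((int)prompt_tokens.size() > BatchConfig::MAX_PROMPT_LENGTH) {
    std::printf("tokens size: %zu\n", prompt_tokens.size());
    return 0;
  }
  static RequestGuid next_guid = 1;
  return next_guid++;
}

GenerationResult generate(std::vector<int> const &prompt_tokens) {
  RequestGuid guid = register_new_request(prompt_tokens);
  if (guid == 0) {
    // Discard the request instead of truncating the prompt.
    std::printf("Discard request: prompt exceeds maximum length\n");
    return GenerationResult();
  }
  GenerationResult result;
  result.output_tokens = prompt_tokens; // real code would run decoding here
  return result;
}

int main() {
  GenerationResult ok = generate(std::vector<int>(16, 42));
  GenerationResult discarded = generate(std::vector<int>(4096, 42));
  std::printf("ok: %zu tokens, discarded: %zu tokens\n",
              ok.output_tokens.size(),
              discarded.output_tokens.size());
  return 0;
}

With this pattern a caller presumably distinguishes a discarded request by an empty (default-constructed) result rather than receiving silently truncated output.
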
26 changes: 21 additions & 5 deletions src/runtime/request_manager.cc
@@ -133,12 +133,13 @@ RequestManager::RequestGuid
          << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
          << prompt.size() << ".\n";
     // Truncate the prompt to MAX_NUM_TOKENS
-    request.tokens.insert(request.tokens.end(),
-                          prompt.begin(),
-                          prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH);
-    request.initial_len = BatchConfig::MAX_PROMPT_LENGTH;
+    // request.tokens.insert(request.tokens.end(),
+    //                       prompt.begin(),
+    //                       prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH);
+    // request.initial_len = BatchConfig::MAX_PROMPT_LENGTH;
     printf("tokens size: %zu\n", request.tokens.size());
     // assert(false);
+    return 0;
   } else {
     request.initial_len = prompt.size();
     request.tokens = prompt;
@@ -194,9 +195,10 @@ RequestManager::RequestGuid
          << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
          << tokens.size() << ".\n";
     // Truncate the prompt to MAX_NUM_TOKENS
-    tokens.resize(BatchConfig::MAX_PROMPT_LENGTH);
+    // tokens.resize(BatchConfig::MAX_PROMPT_LENGTH);
     printf("tokens size: %zu\n", tokens.size());
     // assert(false);
+    return 0;
   }

   for (int i = 0; i < tokens.size(); i++) {
@@ -1562,6 +1564,13 @@ GenerationResult RequestManager::generate_incr_decoding(FFModel *llm,
                                                         int max_seq_length) {
   InferenceManager *im = InferenceManager::get_inference_manager();
   RequestGuid guid = register_new_request(text, max_seq_length);
+  if (guid == 0) {
+    std::cout
+        << "=========== Discard request exceed prompt maximum... ==========="
+        << std::endl;
+    return GenerationResult();
+  }
+
   int tokens_to_generate = max_seq_length - all_requests[guid].tokens.size();
   std::queue<std::pair<BatchConfigFuture, InferenceResultFuture>>
       batch_pipeline;
@@ -1605,6 +1614,13 @@ GenerationResult RequestManager::generate_spec_infer(FFModel *llm,
                                                      int max_seq_length) {
   InferenceManager *im = InferenceManager::get_inference_manager();
   RequestGuid guid = register_new_request(text, max_seq_length);
+  if (guid == 0) {
+    std::cout
+        << "=========== Discard request exceed prompt maximum... ==========="
+        << std::endl;
+    return GenerationResult();
+  }
+
   std::queue<std::pair<TreeVerifyBatchConfigFuture, InferenceResultFuture>>
       batch_pipeline;
   batch_pipeline.push(std::make_pair(last_tree_bcf, last_tree_irf));
13 changes: 5 additions & 8 deletions src/runtime/request_manager.cu
@@ -51,14 +51,11 @@ void RequestManager::load_tokens_task(
   assert(batch_config->num_tokens <= domain.get_volume());
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
-  checkCUDA(
-      cudaMemcpyAsync(fb_ptr,
-                      dram_copy,
-                      batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS
-                          ? sizeof(TokenId) * batch_config->num_tokens
-                          : sizeof(TokenId) * BatchConfig::MAX_NUM_TOKENS,
-                      cudaMemcpyHostToDevice,
-                      stream));
+  checkCUDA(cudaMemcpyAsync(fb_ptr,
+                            dram_copy,
+                            sizeof(TokenId) * batch_config->num_tokens,
+                            cudaMemcpyHostToDevice,
+                            stream));
 }

 void RequestManager::load_positions_task(
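
About the hunk above: the removed ternary clamped the copy size to BatchConfig::MAX_NUM_TOKENS; since over-length prompts are now discarded at registration, batch_config->num_tokens should stay within that bound and the copy can use num_tokens directly. Below is a hedged standalone sketch of the same simplified host-to-device copy in plain CUDA; the buffer sizes, the MAX_NUM_TOKENS value, and the explicit assert are illustrative assumptions, not part of the commit.

// Hedged sketch: copy exactly num_tokens token ids host->device, relying on
// the invariant that num_tokens never exceeds the configured maximum.
// Plain CUDA host code with a stand-in int buffer; error handling is simplified.
#include <cassert>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

using TokenId = int;
constexpr int MAX_NUM_TOKENS = 64; // illustrative stand-in for BatchConfig::MAX_NUM_TOKENS

int main() {
  int num_tokens = 16;                  // would come from the batch config
  assert(num_tokens <= MAX_NUM_TOKENS); // invariant now guaranteed upstream

  std::vector<TokenId> dram_copy(num_tokens, 7); // host-side token buffer
  TokenId *fb_ptr = nullptr;                     // device-side buffer
  cudaMalloc(reinterpret_cast<void **>(&fb_ptr), sizeof(TokenId) * MAX_NUM_TOKENS);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // No clamp needed: the copy size is simply sizeof(TokenId) * num_tokens.
  cudaMemcpyAsync(fb_ptr,
                  dram_copy.data(),
                  sizeof(TokenId) * num_tokens,
                  cudaMemcpyHostToDevice,
                  stream);
  cudaStreamSynchronize(stream);

  cudaFree(fb_ptr);
  cudaStreamDestroy(stream);
  std::printf("copied %d tokens\n", num_tokens);
  return 0;
}
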
