From 5e52aa637fbdc7ed3287640164a756ea462dee74 Mon Sep 17 00:00:00 2001
From: Zeyu Wang
Date: Thu, 27 Jul 2023 03:13:37 +0000
Subject: [PATCH] Discard long prompt request.

---
 src/runtime/request_manager.cc | 26 +++++++++++++++++++++-----
 src/runtime/request_manager.cu | 13 +++++--------
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 514d9d8c6e..8aed0de88a 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -133,12 +133,13 @@ RequestManager::RequestGuid
               << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
               << prompt.size() << ".\n";
     // Truncate the prompt to MAX_NUM_TOKENS
-    request.tokens.insert(request.tokens.end(),
-                          prompt.begin(),
-                          prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH);
-    request.initial_len = BatchConfig::MAX_PROMPT_LENGTH;
+    // request.tokens.insert(request.tokens.end(),
+    //                       prompt.begin(),
+    //                       prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH);
+    // request.initial_len = BatchConfig::MAX_PROMPT_LENGTH;
     printf("tokens size: %zu\n", request.tokens.size());
     // assert(false);
+    return 0;
   } else {
     request.initial_len = prompt.size();
     request.tokens = prompt;
@@ -194,9 +195,10 @@ RequestManager::RequestGuid
               << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
               << tokens.size() << ".\n";
     // Truncate the prompt to MAX_NUM_TOKENS
-    tokens.resize(BatchConfig::MAX_PROMPT_LENGTH);
+    // tokens.resize(BatchConfig::MAX_PROMPT_LENGTH);
     printf("tokens size: %zu\n", tokens.size());
     // assert(false);
+    return 0;
   }
 
   for (int i = 0; i < tokens.size(); i++) {
@@ -1562,6 +1564,13 @@ GenerationResult RequestManager::generate_incr_decoding(FFModel *llm,
                                                         int max_seq_length) {
   InferenceManager *im = InferenceManager::get_inference_manager();
   RequestGuid guid = register_new_request(text, max_seq_length);
+  if (guid == 0) {
+    std::cout
+        << "=========== Discarding request: prompt exceeds maximum ==========="
+        << std::endl;
+    return GenerationResult();
+  }
+
   int tokens_to_generate = max_seq_length - all_requests[guid].tokens.size();
   std::queue<std::pair<BatchConfigFuture, InferenceResultFuture>>
       batch_pipeline;
@@ -1605,6 +1614,13 @@ GenerationResult RequestManager::generate_spec_infer(FFModel *llm,
                                                      int max_seq_length) {
   InferenceManager *im = InferenceManager::get_inference_manager();
   RequestGuid guid = register_new_request(text, max_seq_length);
+  if (guid == 0) {
+    std::cout
+        << "=========== Discarding request: prompt exceeds maximum ==========="
+        << std::endl;
+    return GenerationResult();
+  }
+
   std::queue<std::pair<TreeVerifyBatchConfigFuture, InferenceResultFuture>>
       batch_pipeline;
   batch_pipeline.push(std::make_pair(last_tree_bcf, last_tree_irf));
diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu
index 1bbda58b02..abfcd72a38 100644
--- a/src/runtime/request_manager.cu
+++ b/src/runtime/request_manager.cu
@@ -51,14 +51,11 @@ void RequestManager::load_tokens_task(
   assert(batch_config->num_tokens <= domain.get_volume());
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
-  checkCUDA(
-      cudaMemcpyAsync(fb_ptr,
-                      dram_copy,
-                      batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS
-                          ? sizeof(TokenId) * batch_config->num_tokens
-                          : sizeof(TokenId) * BatchConfig::MAX_NUM_TOKENS,
-                      cudaMemcpyHostToDevice,
-                      stream));
+  checkCUDA(cudaMemcpyAsync(fb_ptr,
+                            dram_copy,
+                            sizeof(TokenId) * batch_config->num_tokens,
+                            cudaMemcpyHostToDevice,
+                            stream));
 }
 
 void RequestManager::load_positions_task(
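
Note on the approach: the patch replaces prompt truncation with outright rejection. register_new_request() now returns guid 0 when the prompt exceeds BatchConfig::MAX_PROMPT_LENGTH, and generate_incr_decoding() / generate_spec_infer() turn that sentinel into an empty GenerationResult. The standalone C++ sketch below illustrates the sentinel-guid pattern in isolation; kMaxPromptLength, register_request(), and the main() driver are illustrative stand-ins, not FlexFlow APIs.

// Standalone sketch (not FlexFlow code) of the sentinel-guid pattern:
// registration refuses an overlong prompt by returning 0, and the caller
// treats guid 0 as "request discarded" instead of truncating the prompt.
#include <cstddef>
#include <cstdio>
#include <vector>

using TokenId = int;
using RequestGuid = unsigned long long;

// Stand-in for BatchConfig::MAX_PROMPT_LENGTH.
constexpr std::size_t kMaxPromptLength = 32;

RequestGuid register_request(std::vector<TokenId> const &prompt) {
  static RequestGuid next_guid = 1; // 0 is reserved as the "rejected" sentinel
  if (prompt.size() > kMaxPromptLength) {
    return 0; // discard the request rather than truncating the prompt
  }
  return next_guid++;
}

int main() {
  std::vector<TokenId> prompt(64, 7); // 64 tokens, exceeds the 32-token limit
  RequestGuid guid = register_request(prompt);
  if (guid == 0) {
    std::printf("request discarded: prompt exceeds maximum length\n");
    return 0;
  }
  std::printf("request accepted with guid %llu\n", guid);
  return 0;
}

Reserving guid 0 keeps the generate_* signatures unchanged: a caller that receives an empty GenerationResult knows the request was dropped before any batch was scheduled.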