From 3e292e55b031134ccf0856134c83a942ba89a5a3 Mon Sep 17 00:00:00 2001
From: Zeyu Wang
Date: Thu, 28 Mar 2024 23:20:09 -0400
Subject: [PATCH] Split prefilling batch from decoding batch for incremental decoding.

---
 include/flexflow/request_manager.h |   3 +
 src/runtime/request_manager.cc     | 130 ++++++++++++++++------------
 2 files changed, 76 insertions(+), 57 deletions(-)

diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index a38a3b2671..331f717811 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -63,6 +63,7 @@ struct Request {
     RUNNING = 102,    // running inference
     COMPLETED = 103,  // finished and verified
     FINISHING = 104,  // finishing request, but not yet verified
+    PREFILLING = 105  // prefilling the prompt
   };
   BatchConfig::RequestGuid guid;
   int max_sequence_length;
@@ -162,6 +163,7 @@ class RequestManager {
                      InferenceResultFuture const &result,
                      Legion::Context ctx,
                      Legion::Runtime *runtime);
+  BatchConfig prepare_prefilling_batch(int i);
   BeamSearchBatchConfig
       prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc,
                               BeamInferenceResult const &result);
@@ -306,6 +308,7 @@ class RequestManager {
     double start_time, finish_time;
   };
   std::unordered_map<RequestGuid, ProfileInfo> profiling_requests;
+  // decoding batch set aside (by value) while a prefilling batch runs
+  BatchConfig buffer_bc;
+  bool has_buffer_bc = false;
   double total_request_run_time;
 };
 
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 16513e918a..92984cedb8 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -54,7 +54,6 @@ RequestManager::RequestManager()
   // ffmodel.compile()
   max_requests_per_batch = -1;
   max_tokens_per_batch = -1;
-  max_spec_tree_token_num = -1;
   max_sequence_length = -1;
 }
 
@@ -76,27 +75,15 @@ void RequestManager::set_max_tokens_per_batch(int max_num_tokens) {
   assert(max_tokens_per_batch <= BatchConfig::MAX_NUM_TOKENS);
 }
 
-void RequestManager::set_max_spec_tree_token_num(int max_num_tokens) {
-  assert(max_spec_tree_token_num == -1 ||
-         max_spec_tree_token_num == max_num_tokens);
-  max_spec_tree_token_num = max_num_tokens;
-  assert(max_spec_tree_token_num <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM);
-}
-
 int RequestManager::get_max_tokens_per_batch() {
   assert(max_tokens_per_batch > 0);
   return max_tokens_per_batch;
 }
 
-int RequestManager::get_max_spec_tree_token_num() {
-  assert(max_spec_tree_token_num > 0);
-  return max_spec_tree_token_num;
-}
-
 int RequestManager::get_max_verify_tokens_per_batch() {
   assert(max_tokens_per_batch > 0);
   return max_tokens_per_batch +
-         max_spec_tree_token_num * max_requests_per_batch;
+         BatchConfig::MAX_SPEC_TREE_TOKEN_NUM * max_requests_per_batch;
 }
 
 void RequestManager::set_max_sequence_length(int max_seq_length) {
@@ -363,6 +350,57 @@ BatchConfig RequestManager::prepare_next_batch_task(
   return rm->prepare_next_batch(*bc, result);
 }
 
+BatchConfig RequestManager::prepare_prefilling_batch(int i) {
+  const std::lock_guard<std::mutex> lock(request_queue_mutex);
+
+  BatchConfig new_bc;
+
+  // mark all slots except i as completed; slot i holds the new request
+  for (int j = 0; j < BatchConfig::max_requests_per_batch(); j++) {
+    if (j == i) {
+      new_bc.request_completed[j] = false;
+    } else {
+      new_bc.request_completed[j] = true;
+    }
+  }
+
+  // pop the top request from the queue
+  Request new_request = pending_request_queue.front();
+  pending_request_queue.pop();
+  new_request.status = Request::PREFILLING;
+  all_requests[new_request.guid] = new_request;
+
+  new_bc.requestsInfo[i].first_token_depth_in_request = 0;
+  new_bc.requestsInfo[i].first_token_offset_in_batch = 0;
+  new_bc.requestsInfo[i].request_guid = new_request.guid;
+  new_bc.requestsInfo[i].num_tokens_in_batch =
+      std::min(get_max_tokens_per_batch(),
+               (int)new_request.tokens.size());
+  new_bc.requestsInfo[i].max_sequence_length =
+      new_request.max_sequence_length;
+  new_bc.request_completed[i] = false;
+  new_bc.requestsInfo[i].prompt_phase = true;
+  // the prefilling request is the only active request in this batch
+  new_bc.requestsInfo[0].batch_config_request_id = i;
+
+  // add profile_info for the new request
+  ProfileInfo profile_info;
+  profile_info.llm_decoding_steps = 1;
+  profile_info.start_time = Realm::Clock::current_time_in_microseconds();
+  profiling_requests[new_request.guid] = profile_info;
+
+  // add the prompt tokens to the batch
+  for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
+    int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
+    new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
+    new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
+    assert(depth < new_request.tokens.size());
+    new_bc.tokensInfo[new_bc.num_tokens].token_id = new_request.tokens[depth];
+    new_bc.num_tokens++;
+  }
+  return new_bc;
+}
+
 BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
                                                InferenceResult const &result) {
   const std::lock_guard<std::mutex> lock(request_queue_mutex);
@@ -385,11 +423,21 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
       // log_req_mgr.print("Output: %s", output.c_str());
     }
   }
 
+  int num_generation_tokens = 0;
   int num_active_req = -1;
   // Step 2: prepare the next batch for existing requests
   BatchConfig new_bc;
+  // resume the decoding batch set aside while a prefilling batch ran
+  if (has_buffer_bc) {
+    new_bc = buffer_bc;
+    has_buffer_bc = false;
+    for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
+      if (!new_bc.request_completed[i]) {
+        num_active_req++;
+      }
+    }
+  }
   for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
     if (old_bc.request_completed[i]) { // add new requests to the next batch
       continue;
     }
@@ -424,6 +472,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
           gr.output_text = output;
         }
         request.status = Request::COMPLETED;
+        new_bc.request_completed[i] = true;
         trigger_request_completion_future(request.guid);
         log_req_mgr.print("[Done] guid(%zu) final_length(%zu)",
                           old_bc.requestsInfo[i].request_guid,
@@ -448,10 +497,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
         std::ofstream outputFile(output_filepath, std::ios::app);
         if (outputFile.is_open()) {
           outputFile << "end-to-end latency: " << std::fixed
-                     << std::setprecision(3) << total_request_run_time
-                     << std::endl;
+                    << std::setprecision(3) << total_request_run_time
+                    << std::endl;
           outputFile << "num decoding steps: "
-                     << profile_info.llm_decoding_steps << std::endl;
+                    << profile_info.llm_decoding_steps << std::endl;
           outputFile << "token IDs: ";
           for (int i = 0; i < request.tokens.size(); i++) {
             outputFile << request.tokens[i];
@@ -489,8 +538,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
         // Prompt phase
         new_bc.requestsInfo[i].num_tokens_in_batch =
             std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
-                     (int)request.tokens.size() -
-                         new_bc.requestsInfo[i].first_token_depth_in_request);
+                    (int)request.tokens.size() -
+                        new_bc.requestsInfo[i].first_token_depth_in_request);
         new_bc.requestsInfo[i].prompt_phase = true;
       }
       for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
@@ -514,39 +563,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
     if (new_bc.request_completed[i]) {
       if (!pending_request_queue.empty() &&
          new_bc.num_tokens < get_max_tokens_per_batch()) {
-        Request new_request = pending_request_queue.front();
-        pending_request_queue.pop();
-        // all_requests[new_request.guid] = new_request;
-
-        new_bc.requestsInfo[i].first_token_depth_in_request = 0;
-        new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
-        new_bc.requestsInfo[i].request_guid = new_request.guid;
-        new_bc.requestsInfo[i].num_tokens_in_batch =
-            std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
-                     (int)new_request.tokens.size());
-        new_bc.requestsInfo[i].max_sequence_length =
-            new_request.max_sequence_length;
-        new_bc.request_completed[i] = false;
-        new_bc.requestsInfo[i].prompt_phase = true;
-        num_active_req++;
-        new_bc.requestsInfo[num_active_req].batch_config_request_id = i;
-        // add profile_info for the new request
-        ProfileInfo profile_info;
-        profile_info.llm_decoding_steps = 1;
-        profile_info.start_time = Realm::Clock::current_time_in_microseconds();
-        profiling_requests[new_request.guid] = profile_info;
-        for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-          int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
-          new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
-          new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
-          assert(depth < new_request.tokens.size());
-          new_bc.tokensInfo[new_bc.num_tokens].token_id =
-              new_request.tokens[depth];
-          new_bc.num_tokens++;
-        }
-        if (new_bc.num_tokens == get_max_tokens_per_batch()) {
-          break;
-        }
+        // set the decoding batch aside and run a prefill-only batch instead
+        buffer_bc = new_bc;
+        has_buffer_bc = true;
+        new_bc = prepare_prefilling_batch(i);
+        // admit at most one prefilling request per step, otherwise the
+        // buffered decoding batch would be overwritten
+        break;
       }
     }
   }
@@ -1577,11 +1595,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
     }
 
     if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) {
-      printf("Exceeding (%i) the space available (%i) in the TreeVerify "
-             "batch\n",
-             new_bc.num_tokens,
-             get_max_verify_tokens_per_batch());
-      assert(false);
+      assert(false &&
+             "Exceeding the space available in the TreeVerify batch");
+      break;
     }
 
     if (new_bc.requestsInfo[i].num_tokens_in_batch +
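
Note: independent of FlexFlow's actual BatchConfig/RequestManager types, the scheduling idea in this patch is a small state machine: when a new request arrives while a decoding batch is being assembled, the half-built decoding batch is set aside, a prefill-only batch is issued for one step, and the saved decoding batch is resumed on the next step. The sketch below illustrates that pattern in isolation; the types and names (ToyBatch, ToyScheduler, next_batch) are simplified stand-ins invented for this example, not FlexFlow APIs.

#include <cassert>
#include <optional>
#include <queue>
#include <vector>

// Simplified stand-ins for BatchConfig / Request bookkeeping; illustrative only.
struct ToyBatch {
  bool prefill = false;         // true -> prompt (prefill) step
  std::vector<int> request_ids; // requests served in this step
};

struct ToyScheduler {
  std::queue<int> pending;          // newly arrived request ids
  std::vector<int> running;         // requests already in incremental decoding
  std::optional<ToyBatch> buffered; // decoding batch set aside for one step

  // Build the batch for the next step, mirroring the patch's control flow:
  // 1. if a decoding batch was buffered last step, resume it;
  // 2. otherwise build a decoding batch, and if a new request is waiting,
  //    buffer that decoding batch and run a prefill-only batch instead.
  ToyBatch next_batch() {
    if (buffered) { // step after a prefill: resume decoding
      ToyBatch b = *buffered;
      buffered.reset();
      return b;
    }
    ToyBatch decode;
    decode.request_ids = running; // incremental decoding for running requests
    if (!pending.empty()) {
      int new_id = pending.front();
      pending.pop();
      buffered = decode;            // set the decoding batch aside (by value)
      running.push_back(new_id);    // the new request decodes from now on
      ToyBatch prefill;
      prefill.prefill = true;
      prefill.request_ids = {new_id}; // prefill serves only the new request
      return prefill;
    }
    return decode;
  }
};

int main() {
  ToyScheduler s;
  s.running = {0, 1};
  s.pending.push(2);

  ToyBatch b1 = s.next_batch(); // prefill-only batch for request 2
  assert(b1.prefill && b1.request_ids.size() == 1);
  ToyBatch b2 = s.next_batch(); // resumed decoding batch for requests 0 and 1
  assert(!b2.prefill && b2.request_ids.size() == 2);
  return 0;
}

Keeping the set-aside batch by value (buffer_bc plus the has_buffer_bc flag above) rather than holding a pointer to a stack-local batch avoids a dangling reference between scheduling steps.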