Split prefilling batch from decoding batch for incremental decoding.
zwang86 committed Mar 29, 2024
1 parent 0479a64 commit 3e292e5
Showing 2 changed files with 76 additions and 57 deletions.
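
The idea, roughly: instead of mixing a new request's prompt tokens into a batch that is already decoding, the request manager parks the in-flight decoding batch, runs a dedicated prefilling batch for the new prompt, and resumes the parked batch on the next step. Below is a minimal C++ sketch of that scheduling pattern; it is illustrative only, and Scheduler, Batch, parked_, and pending_prompts_ are simplified stand-ins rather than the FlexFlow RequestManager API.

#include <optional>
#include <queue>

struct Batch {
  // token and per-request metadata for one inference step
};

class Scheduler {
public:
  // Decide what to run next, given the batch that just finished.
  Batch next_batch(Batch const &current) {
    if (parked_) {
      // A prefilling-only batch just ran; resume the decoding batch
      // that was set aside one step earlier.
      Batch resumed = *parked_;
      parked_.reset();
      return resumed;
    }
    if (!pending_prompts_.empty()) {
      // A new prompt arrived: park the decoding batch and build a batch
      // that contains only the new prompt's tokens (the prefill step).
      parked_ = current;
      pending_prompts_.pop();
      return Batch{}; // placeholder for the prefilling batch
    }
    // Steady state: keep decoding the current set of requests.
    return current;
  }

private:
  std::optional<Batch> parked_;     // decoding batch waiting while prefill runs
  std::queue<int> pending_prompts_; // ids of requests awaiting prefill
};

In the commit itself, the parked batch is held in RequestManager::buffer_bc and the prefilling batch is built by prepare_prefilling_batch, as shown in the diff below.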
3 changes: 3 additions & 0 deletions include/flexflow/request_manager.h
@@ -63,6 +63,7 @@ struct Request {
RUNNING = 102, // running inference
COMPLETED = 103, // finished and verified
FINISHING = 104, // finishing request, but not yet verified
PREFILLING = 105 // prefilling the prompt
};
BatchConfig::RequestGuid guid;
int max_sequence_length;
@@ -162,6 +163,7 @@ class RequestManager {
InferenceResultFuture const &result,
Legion::Context ctx,
Legion::Runtime *runtime);
BatchConfig prepare_prefilling_batch(int i);
BeamSearchBatchConfig
prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc,
BeamInferenceResult const &result);
@@ -306,6 +308,7 @@ class RequestManager {
double start_time, finish_time;
};
std::unordered_map<RequestGuid, ProfileInfo> profiling_requests;
BatchConfig *buffer_bc = nullptr; // owned; decoding batch parked while a prefilling-only batch runs
double total_request_run_time;
};

130 changes: 73 additions & 57 deletions src/runtime/request_manager.cc
@@ -54,7 +54,6 @@ RequestManager::RequestManager()
// ffmodel.compile()
max_requests_per_batch = -1;
max_tokens_per_batch = -1;
max_spec_tree_token_num = -1;
max_sequence_length = -1;
}

@@ -76,27 +75,15 @@ void RequestManager::set_max_tokens_per_batch(int max_num_tokens) {
assert(max_tokens_per_batch <= BatchConfig::MAX_NUM_TOKENS);
}

void RequestManager::set_max_spec_tree_token_num(int max_num_tokens) {
assert(max_spec_tree_token_num == -1 ||
max_spec_tree_token_num == max_num_tokens);
max_spec_tree_token_num = max_num_tokens;
assert(max_spec_tree_token_num <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM);
}

int RequestManager::get_max_tokens_per_batch() {
assert(max_tokens_per_batch > 0);
return max_tokens_per_batch;
}

int RequestManager::get_max_spec_tree_token_num() {
assert(max_spec_tree_token_num > 0);
return max_spec_tree_token_num;
}

int RequestManager::get_max_verify_tokens_per_batch() {
assert(max_tokens_per_batch > 0);
return max_tokens_per_batch +
max_spec_tree_token_num * max_requests_per_batch;
BatchConfig::MAX_SPEC_TREE_TOKEN_NUM * max_requests_per_batch;
}

void RequestManager::set_max_sequence_length(int max_seq_length) {
@@ -363,6 +350,57 @@ BatchConfig RequestManager::prepare_next_batch_task(
return rm->prepare_next_batch(*bc, result);
}

BatchConfig RequestManager::prepare_prefilling_batch(int i) {
// No lock acquired here: the caller (prepare_next_batch) already holds
// request_queue_mutex, and std::mutex is not recursive.

BatchConfig new_bc;

// mark every slot except i as empty/completed
for (int j = 0; j < BatchConfig::max_requests_per_batch(); j++) {
  if (j == i) {
    new_bc.request_completed[j] = false;
  } else {
    new_bc.request_completed[j] = true;
  }
}

// pop top request from the queue
Request new_request = pending_request_queue.front();
pending_request_queue.pop();
new_request.status = Request::PREFILLING;
all_requests[new_request.guid] = new_request;

new_bc.requestsInfo[i].first_token_depth_in_request = 0;
new_bc.requestsInfo[i].first_token_offset_in_batch = 0;
new_bc.requestsInfo[i].request_guid = new_request.guid;
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch(),
(int)new_request.tokens.size());
new_bc.requestsInfo[i].max_sequence_length =
new_request.max_sequence_length;
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].prompt_phase = true;
// slot i is the only active request in this prefilling-only batch
new_bc.requestsInfo[0].batch_config_request_id = i;

// add profile_info for the new request
ProfileInfo profile_info;
profile_info.llm_decoding_steps = 1;
profile_info.start_time = Realm::Clock::current_time_in_microseconds();
profiling_requests[new_request.guid] = profile_info;

// add tokens to the batch
for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
assert(depth < new_request.tokens.size());
new_bc.tokensInfo[new_bc.num_tokens].token_id =
new_request.tokens[depth];
new_bc.num_tokens++;
}
return new_bc;
}

BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
InferenceResult const &result) {
const std::lock_guard<std::mutex> lock(request_queue_mutex);
@@ -385,11 +423,21 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
// log_req_mgr.print("Output: %s", output.c_str());
}
}

int num_generation_tokens = 0;
int num_active_req = -1;

// Step 2: prepare the next batch for existing requests
BatchConfig new_bc;
if (buffer_bc != nullptr) {
  // resume the decoding batch that was parked while the prefilling batch ran
  new_bc = *buffer_bc;
  delete buffer_bc;
  buffer_bc = nullptr;
  for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
    if (!new_bc.request_completed[i]) {
      num_active_req++;
    }
  }
}

for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
if (old_bc.request_completed[i]) { // add new requests to the next batch
continue;
@@ -424,6 +472,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
gr.output_text = output;
}
request.status = Request::COMPLETED;
new_bc.request_completed[i] = true;
trigger_request_completion_future(request.guid);
log_req_mgr.print("[Done] guid(%zu) final_length(%zu)",
old_bc.requestsInfo[i].request_guid,
@@ -448,10 +497,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
std::ofstream outputFile(output_filepath, std::ios::app);
if (outputFile.is_open()) {
outputFile << "end-to-end latency: " << std::fixed
<< std::setprecision(3) << total_request_run_time
<< std::endl;
outputFile << "num decoding steps: "
<< profile_info.llm_decoding_steps << std::endl;
outputFile << "token IDs: ";
for (int i = 0; i < request.tokens.size(); i++) {
outputFile << request.tokens[i];
@@ -489,8 +538,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
// Prompt phase
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
(int)request.tokens.size() -
new_bc.requestsInfo[i].first_token_depth_in_request);
new_bc.requestsInfo[i].prompt_phase = true;
}
for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
@@ -514,39 +563,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
if (new_bc.request_completed[i]) {
if (!pending_request_queue.empty() &&
new_bc.num_tokens < get_max_tokens_per_batch()) {
Request new_request = pending_request_queue.front();
pending_request_queue.pop();
// all_requests[new_request.guid] = new_request;

new_bc.requestsInfo[i].first_token_depth_in_request = 0;
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid = new_request.guid;
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
(int)new_request.tokens.size());
new_bc.requestsInfo[i].max_sequence_length =
new_request.max_sequence_length;
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].prompt_phase = true;
num_active_req++;
new_bc.requestsInfo[num_active_req].batch_config_request_id = i;
// add profile_info for the new request
ProfileInfo profile_info;
profile_info.llm_decoding_steps = 1;
profile_info.start_time = Realm::Clock::current_time_in_microseconds();
profiling_requests[new_request.guid] = profile_info;
for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
assert(depth < new_request.tokens.size());
new_bc.tokensInfo[new_bc.num_tokens].token_id =
new_request.tokens[depth];
new_bc.num_tokens++;
}
if (new_bc.num_tokens == get_max_tokens_per_batch()) {
break;
}
// Park a heap copy of the current decoding batch (a pointer to the local
// new_bc would dangle once this call returns) and switch to a batch that
// prefills only the newly admitted request.
buffer_bc = new BatchConfig(new_bc);
new_bc = prepare_prefilling_batch(i);
}
}
}
@@ -1577,11 +1595,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
}

if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) {
printf("Exceeding (%i) the space available (%i) in the TreeVerify "
"batch\n",
new_bc.num_tokens,
get_max_verify_tokens_per_batch());
assert(false);
assert(false &&
"Exceeding the space available in the TreeVerify batch");
break;
}

if (new_bc.requestsInfo[i].num_tokens_in_batch +
