diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index ce331d3e41..8aa69a3cad 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -46,13 +46,14 @@ class BatchConfig { void print() const; virtual InferenceMode get_mode() const; static BatchConfig const *from_future(BatchConfigFuture const &future); - static int const MAX_NUM_REQUESTS = 1; + static int const MAX_NUM_REQUESTS = 4; static int const MAX_NUM_TOKENS = 64; static int const MAX_PROMPT_LENGTH = 62; static int const MAX_SEQ_LENGTH = 256; // These are set by update int num_tokens; + bool loading_prompt = false; struct PerRequestInfo { int token_start_offset; @@ -69,6 +70,7 @@ class BatchConfig { PerTokenInfo tokensInfo[MAX_NUM_TOKENS]; bool request_completed[MAX_NUM_REQUESTS]; + bool request_running[MAX_NUM_TOKENS]; }; class TreeVerifyBatchConfig : public BatchConfig { @@ -113,7 +115,6 @@ class BeamSearchBatchConfig : public BatchConfig { inline static int const MAX_BEAM_DEPTH = 8; int model_id; - int max_init_length = 0; struct BeamSearchPerRequestInfo { int beam_size; diff --git a/include/flexflow/model.h b/include/flexflow/model.h index f88f96cd5a..177575e809 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -239,8 +239,8 @@ enum TaskIDs { RM_LOAD_TOKENS_TASK_ID, RM_LOAD_POSITION_TASK_ID, RM_PREPARE_NEXT_BATCH_TASK_ID, - RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, RM_PREPARE_NEXT_BATCH_INIT_TASK_ID, + RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID, // Custom tasks CUSTOM_GPU_TASK_ID_FIRST, @@ -787,7 +787,8 @@ class FFModel { // ======================================== // Inference APIs // ======================================== - GenerationResult generate(std::string const &text, int max_seq_length); + GenerationResult generate(std::vector<std::string> &prompts, + int max_seq_length); Tensor create_tensor_legion_ordering(int num_dim, int const dims[], diff --git a/include/flexflow/ops/kernels/softmax_kernels.h b/include/flexflow/ops/kernels/softmax_kernels.h index 14c07414e9..987a546459 100644 --- a/include/flexflow/ops/kernels/softmax_kernels.h +++ b/include/flexflow/ops/kernels/softmax_kernels.h @@ -15,8 +15,10 @@ class SoftmaxMeta : public OpMeta { Legion::Domain const &input_domain); #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) cudnnTensorDescriptor_t inputTensor; + cudnnTensorDescriptor_t outputTensor; #else miopenTensorDescriptor_t inputTensor; + miopenTensorDescriptor_t outputTensor; #endif bool profiling; int dim; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index e444402dd0..8515d8a04b 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -52,13 +52,17 @@ class InferenceManager { struct Request { enum Status { - PENDING = 101, - RUNNING = 102, - COMPLETED = 103, + PENDING = 101, // loading prompt + RUNNING = 102, // running inference + COMPLETED = 103, // finished and verified + FINISHING = 104, // finishing request, but not yet verified }; BatchConfig::RequestGuid guid; int max_sequence_length; int initial_len; + int ssm_cache_size = 0; + int llm_cache_size = 0; + Status status = PENDING; std::vector<BatchConfig::TokenId> tokens; @@ -102,10 +106,10 @@ class RequestManager { FFModel *get_model(int model_id); GenerationResult generate_incr_decoding(FFModel *model, - std::string const &text, + std::vector<std::string> &prompts, int max_seq_length); GenerationResult generate_spec_infer(FFModel *model, - std::string const &text, + std::vector<std::string> &prompts, int max_seq_length); 
GenerationResult get_generation_result(RequestGuid const &guid); RequestGuid register_new_request(std::string const &prompt, diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index 19cd8726e2..3f913e4573 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -242,13 +242,15 @@ void FlexFlow::top_level_task(Task const *task, /*parser_callback_t */ nullptr, /*allow_exceptions */ true, /*ignore_comments */ true); + std::vector<std::string> prompts; for (auto &prompt : prompt_json) { std::string text = prompt.get<std::string>(); printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); total_num_requests++; - GenerationResult result = - model.generate(text, 128 /*max_sequence_length*/); + prompts.push_back(text); } + GenerationResult result = + model.generate(prompts, 128 /*max_sequence_length*/); } // Execution fence diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 9d139997f7..2b1fb6e817 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -384,12 +384,16 @@ void FlexFlow::top_level_task(Task const *task, /*parser_callback_t */ nullptr, /*allow_exceptions */ true, /*ignore_comments */ true); + + std::vector<std::string> prompts; for (auto &prompt : prompt_json) { std::string text = prompt.get<std::string>(); printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); total_num_requests++; - tree_model.generate(text, 128 /*max_sequence_length*/); + prompts.push_back(text); + // tree_model.generate(text, 128 /*max_sequence_length*/); } + tree_model.generate(prompts, 128 /*max_sequence_length*/); } // Execution fence diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 0ebe29e3e9..fcdae9cf33 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1529,8 +1529,10 @@ flexflow_generation_result_t int max_seq_length, int *output_length_and_tokens) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); + std::vector<std::string> prompts; std::string const text_str(input_text); - GenerationResult result = handle->generate(text_str, max_seq_length); + prompts.push_back(input_text); + GenerationResult result = handle->generate(prompts, max_seq_length); DEBUG_PRINT("[Model] generate %p %s %i", handle, text_str, max_seq_length); assert(result.output_tokens.size() <= max_seq_length); output_length_and_tokens[0] = result.output_tokens.size(); diff --git a/src/mapper/mapper.cc b/src/mapper/mapper.cc index 3d08eb0bcc..a86a6167a6 100644 --- a/src/mapper/mapper.cc +++ b/src/mapper/mapper.cc @@ -284,8 +284,8 @@ void FFMapper::select_task_options(const MapperContext ctx, return; } if ((task.task_id == RM_PREPARE_NEXT_BATCH_TASK_ID) || - (task.task_id == RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID) || (task.task_id == RM_PREPARE_NEXT_BATCH_INIT_TASK_ID) || + (task.task_id == RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID) || (task.task_id == RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID)) { output.initial_proc = all_cpus[0]; return; diff --git a/src/ops/argmax.cpp b/src/ops/argmax.cpp index ec5ea6c36a..8a1cf0b3b0 100644 --- a/src/ops/argmax.cpp +++ b/src/ops/argmax.cpp @@ -393,7 +393,7 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, if (m->beam_search) { // set all parents id zero in arg top1 case. 
- checkCUDA(hipMemset(parent, 0, batch_size * sizeof(int))); + checkCUDA(hipMemsetAsync(parent, 0, batch_size * sizeof(int), stream)); } int num_shards = 0; int k = 1; diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu index 37e067006c..05c84719c1 100644 --- a/src/ops/argmax.cu +++ b/src/ops/argmax.cu @@ -59,7 +59,7 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, DT alpha = 1.0f, beta = 0.0f; if (m->beam_search) { // set all parents id zero in arg top1 case. - checkCUDA(cudaMemset(parent, 0, batch_size * sizeof(int))); + checkCUDA(cudaMemsetAsync(parent, 0, batch_size * sizeof(int), stream)); } size_t temp_storage_bytes = m->temp_storage_bytes; // use cub @@ -83,6 +83,7 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, prob_ptr, batch_size, m->beam_search); + // print_tensor(indices_ptr, 32, "argmax op"); } /*static*/ @@ -93,7 +94,6 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, int batch_size) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - cudaEvent_t t_start, t_end; if (m->profiling) { cudaEventCreate(&t_start); diff --git a/src/ops/kernels/softmax.cpp b/src/ops/kernels/softmax.cpp index bd8b46116d..ca4872d51b 100644 --- a/src/ops/kernels/softmax.cpp +++ b/src/ops/kernels/softmax.cpp @@ -29,6 +29,9 @@ SoftmaxMeta::SoftmaxMeta(FFHandler handler, checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN( cudnnSetTensorDescriptorFromDomain4SoftMax(inputTensor, input_domain)); + checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); + checkCUDNN( + cudnnSetTensorDescriptorFromDomain4SoftMax(outputTensor, input_domain)); dim = softmax->dim; profiling = softmax->profiling; std::strcpy(op_name, softmax->name); @@ -127,7 +130,7 @@ void forward_kernel(SoftmaxMeta const *m, m->inputTensor, input_ptr, &beta, - m->inputTensor, + m->outputTensor, output_ptr, MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_CHANNEL)); diff --git a/src/ops/kernels/softmax.cu b/src/ops/kernels/softmax.cu index 15130c19a7..67a9c21038 100644 --- a/src/ops/kernels/softmax.cu +++ b/src/ops/kernels/softmax.cu @@ -28,6 +28,9 @@ SoftmaxMeta::SoftmaxMeta(FFHandler handler, checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); checkCUDNN(cudnnSetTensorDescriptorFromDomain4SoftMax( inputTensor, input_domain, softmax->data_type)); + checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); + checkCUDNN(cudnnSetTensorDescriptorFromDomain4SoftMax( + outputTensor, input_domain, softmax->data_type)); dim = softmax->dim; profiling = softmax->profiling; std::strcpy(op_name, softmax->name); @@ -42,7 +45,6 @@ void forward_kernel_wrapper(SoftmaxMeta const *m, DT *output_ptr) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - cudaEvent_t t_start, t_end; if (m->profiling) { cudaEventCreate(&t_start); @@ -127,7 +129,7 @@ void forward_kernel(SoftmaxMeta const *m, m->inputTensor, input_ptr, &beta, - m->inputTensor, + m->outputTensor, output_ptr)); } diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 6ef5145654..b4cdc77e2a 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -251,6 +251,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, if (bc->request_completed[i]) { continue; } + for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) { // int num_new_tokens = bc->num_processing_tokens[i]; @@ -259,6 +260,11 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, int num_new_tokens = 
bc->requestsInfo[i].num_tokens_in_batch; int total_tokens = bc->requestsInfo[i].token_start_offset + bc->requestsInfo[i].num_tokens_in_batch; + + if (num_new_tokens <= 0) { + continue; + } + // Compute (QK^T/sqrt(d_k)) int m_ = num_new_tokens; int n = total_tokens; @@ -543,7 +549,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, output_ptr, bias_ptr, num_tokens, qkv_weight_size, m->oProjSize); } - assert(tokens_previous_requests == num_tokens); + // assert(tokens_previous_requests == num_tokens); } template diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 5eb3192e25..5489c9b06d 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -143,17 +143,12 @@ RequestManager::RequestGuid request.guid = next_available_guid++; request.max_sequence_length = max_sequence_length; - if (prompt.size() > BatchConfig::MAX_PROMPT_LENGTH) { + if (prompt.size() >= BatchConfig::MAX_SEQ_LENGTH) { std::cout << "Warning: too many tokens in prompt, only load up to " - << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got " + << BatchConfig::MAX_SEQ_LENGTH << " tokens, but got " << prompt.size() << ".\n"; - // Truncate the prompt to MAX_NUM_TOKENS - // request.tokens.insert(request.tokens.end(), - // prompt.begin(), - // prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH); - // request.initial_len = BatchConfig::MAX_PROMPT_LENGTH; + printf("tokens size: %zu\n", request.tokens.size()); - // assert(false); return 0; } else { request.initial_len = prompt.size(); @@ -206,14 +201,12 @@ RequestManager::RequestGuid request.tokens.push_back(bos_token_id); } std::vector<int32_t> tokens = this->tokenizer_->Encode(prompt); - if (tokens.size() > BatchConfig::MAX_PROMPT_LENGTH) { + if (tokens.size() >= BatchConfig::MAX_SEQ_LENGTH) { std::cout << "Warning: too many tokens in prompt, only load up to " - << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got " + << BatchConfig::MAX_SEQ_LENGTH << " tokens, but got " << tokens.size() << ".\n"; - // Truncate the prompt to MAX_NUM_TOKENS - // tokens.resize(BatchConfig::MAX_PROMPT_LENGTH); + printf("tokens size: %zu\n", tokens.size()); - // assert(false); return 0; } for (int i = 0; i < tokens.size(); i++) { @@ -238,6 +231,7 @@ RequestManager::RequestGuid all_requests[request.guid] = request; { std::string output = "New request tokens:"; + output = "[" + std::to_string(request.guid) + "]" + output; for (int i = 0; i < request.tokens.size(); i++) { output = output + " " + std::to_string(request.tokens[i]); } @@ -467,149 +461,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } /* ----- Speculative Inference Specific functions ----- */ -BeamSearchBatchConfigFuture RequestManager::prepare_next_batch_beam( - BeamSearchBatchConfigFuture const &old_bc, - BeamInferenceResultFuture const &result) { - Runtime *runtime = Runtime::get_runtime(); - Context ctx = Runtime::get_context(); - - RequestManager *rm = this; - TaskLauncher launcher(RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, - TaskArgument(&rm, sizeof(RequestManager *))); - launcher.add_future(old_bc); - launcher.add_future(result); - return runtime->execute_task(ctx, launcher); -} - -BeamSearchBatchConfig RequestManager::prepare_next_batch_beam_task( - Task const *task, - std::vector<PhysicalRegion> const &regions, - Context ctx, - Runtime *runtime) { - RequestManager *rm = *((RequestManager **)task->args); - BeamSearchBatchConfig const &bc = - Future(task->futures[0]).get_result<BeamSearchBatchConfig>(); - BeamInferenceResult const &result = - Future(task->futures[1]).get_result<BeamInferenceResult>(); 
- return rm->prepare_next_batch_beam(bc, result); -} - -// update beam search metadata -BeamSearchBatchConfig - RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, - BeamInferenceResult const &result) { - const std::lock_guard lock(request_queue_mutex); - if (verbose) { - std::cout << "\n############### prepare_next_batch_beam ###############\n"; - } - if (verbose) { - std::cout << "print all results" - << "\n"; - for (int i = 0; i < 40; i++) { - std::cout << result.token_ids[i] << ", "; - } - std::cout << "Current Beam Depth: " - << old_bc.beamRequestsInfo[0].current_depth << "\n"; - } - - // Step 1: Store result to the beam tree struct - store_beam_metadata(old_bc, result); - - // Step 2: preparing the next batch for existing requests - BeamSearchBatchConfig new_bc; - new_bc.max_init_length = 0; - new_bc.model_id = old_bc.model_id; - // std::cout << "old_bc.model_id: " << old_bc.model_id << "\n"; - - for (int i = 0; i < BatchConfig::MAX_NUM_REQUESTS; i++) { - if (old_bc.request_completed[i]) { - continue; - } - // Comment out this assertion since num_tokens_in_batch can be - // zero when beam search has reached required sequence length - // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); - Request &request = all_requests[old_bc.requestsInfo[i].request_guid]; - int processed_tokens = old_bc.requestsInfo[i].token_start_offset + - old_bc.requestsInfo[i].num_tokens_in_batch; - - // assert(processed_tokens < request.tokens.size()); - log_req_mgr.debug() << "processed_tokens: " << processed_tokens << "\n"; - if (processed_tokens > - old_bc.beamRequestsInfo[i].max_depth + request.tokens.size() - // || ir.results[t] == 0 TODO: replace this with - ) { - log_req_mgr.print("[Done] guid(%zu) with spec_tree_depth(%d)", - old_bc.requestsInfo[i].request_guid, - old_bc.beamRequestsInfo[i].max_depth); - // new_bc.request_completed[i] = true; - new_bc.request_completed[i] = false; - new_bc.requestsInfo[i].token_start_offset = processed_tokens; - new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; - new_bc.requestsInfo[i].max_sequence_length = - old_bc.requestsInfo[i].max_sequence_length; - } else { - log_req_mgr.debug() << "num tokens: " << old_bc.num_tokens << ", " - << new_bc.num_tokens; - new_bc.request_completed[i] = false; - new_bc.requestsInfo[i].token_start_offset = processed_tokens; - new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; - new_bc.requestsInfo[i].max_sequence_length = - old_bc.requestsInfo[i].max_sequence_length; - - // update the beam search metadata - // how many sub request in current request - // why is sub_requests has MAX_NUM_REQUESTS * MAX_BEAM_WIDTH entries? - new_bc.sub_requests[i] = old_bc.beamRequestsInfo[i].beam_size; - // update the parentid, accumalated_probs, depth, and token_ids - new_bc.beamRequestsInfo[i].current_depth = - old_bc.beamRequestsInfo[i].current_depth + 1; - new_bc.beamRequestsInfo[i].beam_size = - old_bc.beamRequestsInfo[i].beam_size; - new_bc.beamRequestsInfo[i].max_depth = - old_bc.beamRequestsInfo[i].max_depth; - - // do the slot exchange to minimize the cache exchange in kernel. 
- // std::cout << "update metadata" << std::endl; - update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id), i); - - if (new_bc.requestsInfo[i].token_start_offset + 1 >= - request.tokens.size()) { - // Incremental phase - new_bc.requestsInfo[i].num_tokens_in_batch = 1; - } else { - // Prompt phase - new_bc.requestsInfo[i].num_tokens_in_batch = - std::min(BatchConfig::MAX_NUM_TOKENS - new_bc.num_tokens, - (int)request.tokens.size() - - new_bc.requestsInfo[i].token_start_offset); - } - - // register more tokens due to the beam width - for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { - int depth = new_bc.requestsInfo[i].token_start_offset + j; - for (int k = 0; k < new_bc.sub_requests[i]; k++) { - new_bc.tokensInfo[new_bc.num_tokens].request_index = i; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; - - // get value from requestinfo - new_bc.tokensInfo[new_bc.num_tokens].token_id = - new_bc.beamRequestsInfo[i].tokens[k]; - // request.tokens[depth]; - new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; - new_bc.num_tokens++; - } - } - } - } - if (verbose) { - std::cout << "prepare_next_batch_beam OLD vs NEW batchconfigs:" - << std::endl; - old_bc.print(); - new_bc.print(); - } - return new_bc; -} +/***** Request Init Phase *****/ BeamSearchBatchConfigFuture RequestManager::prepare_next_batch_init( TreeVerifyBatchConfigFuture const &old_bc, InferenceResultFuture const &result, @@ -648,6 +501,9 @@ BeamSearchBatchConfig if (verbose) { std::cout << "\n############### prepare_next_batch_init ###############\n"; } + + std::cout << "\n############### prepare_next_batch_init ###############\n"; + // Step 1: use result to update requests BeamSearchBatchConfig new_bc; new_bc.num_tokens = 0; @@ -661,188 +517,226 @@ BeamSearchBatchConfig size_t guid = old_bc.requestsInfo[i].request_guid; Request &request = all_requests[guid]; + std::cout << "[ " << guid << " ]" << std::endl; + // Verify this: get verified tokens from result std::vector<std::pair<BatchConfig::TokenId, int>> tree_outputs = std::vector<std::pair<BatchConfig::TokenId, int>>(); assert(old_bc.num_tokens > 0); - int start_depth = old_bc.tokensInfo[result_index].abs_depth_in_request; - if (committed_tokens.find(guid) == committed_tokens.end()) { - committed_tokens[guid] = std::vector<std::pair<int, int>>(); + // reset committed_tokens + if (committed_tokens.count(guid) == 0) { + committed_tokens[guid] = {}; } else { - committed_tokens.at(guid).clear(); + committed_tokens[guid].clear(); + } + // iterate through all the tokens that belong to request i + int root_abs_depth = request.tokens.size() - 1; + while (result_index < old_bc.num_tokens && old_bc.tokensInfo[result_index].request_index == i) { - // new tokens have not been appended yet, so the last appended token is - // the root of the beam search token tree - int root_abs_depth = request.tokens.size() - 1; - if (old_bc.tokensInfo[result_index].abs_depth_in_request >= - root_abs_depth) { - // append to tree_outputs a pair consisting of (token id, depth) - tree_outputs.push_back(std::make_pair( - result.token_ids[result_index], - old_bc.tokensInfo[result_index].abs_depth_in_request + 1)); - // append (depth, index of the token in result) to committed_tokens - // array - committed_tokens.at(guid).push_back( - std::make_pair(old_bc.tokensInfo[result_index].abs_depth_in_request, - result_index)); + int abs_depth = old_bc.tokensInfo[result_index].abs_depth_in_request; + int token_id = result.token_ids[result_index]; + + if (request.status == Request::PENDING) { + committed_tokens[guid].emplace_back(abs_depth, result_index); + } 
else if (abs_depth >= root_abs_depth) { + tree_outputs.emplace_back(token_id, abs_depth + 1); + committed_tokens[guid].emplace_back(abs_depth, result_index); if (verbose) { std::cout << "Index within old batch: " << result_index << std::endl; printf(" Input: [%d] %d ---> [%d] %d \n", - old_bc.tokensInfo[result_index].abs_depth_in_request, + abs_depth, old_bc.tokensInfo[result_index].token_id, tree_outputs.back().second, - tree_outputs.back().first); + token_id); } - // std::cout << " Input: " << old_bc.tokensInfo[result_index].token_id - // << "" - // << old_bc.tokensInfo[result_index].abs_depth_in_request << - // std::endl; - // std::cout << " Result: " << result.token_ids[result_index] << ", - // depth: " - // << old_bc.tokensInfo[result_index].abs_depth_in_request + 1 << - // std::endl; + std::cout << "Index within old batch: " << result_index << std::endl; + printf(" Input: [%d] %d ---> [%d] %d \n", + abs_depth, + old_bc.tokensInfo[result_index].token_id, + tree_outputs.back().second, + token_id); } result_index++; } - std::vector<std::pair<BatchConfig::TokenId, int>> verified_tokens = - traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); - log_req_mgr.print("Number of Verified Tokens = %zu", - verified_tokens.size()); - // check if the request is finished - if (verified_tokens.size() + request.tokens.size() >= - request.max_sequence_length) { - // Append all verified tokens to the request - for (int j = 0; j < verified_tokens.size(); j++) { - if (verified_tokens[j].second < request.max_sequence_length) { - request.tokens.push_back(verified_tokens[j].first); + if (request.status == Request::RUNNING) { + std::vector<std::pair<BatchConfig::TokenId, int>> verified_tokens = + traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); + log_req_mgr.print("Number of Verified Tokens = %zu", + verified_tokens.size()); + + // check if the request is finished + if (verified_tokens.size() + request.tokens.size() >= + request.max_sequence_length) { + // Append all verified tokens to the request + for (auto const &token_pair : verified_tokens) { + if (token_pair.second < request.max_sequence_length) { + request.tokens.push_back(token_pair.first); + } } - } - request.status = Request::COMPLETED; - log_req_mgr.print("[Done] guid(%zu) with final length(%zu)", - request.guid, - request.tokens.size()); - std::string output = this->tokenizer_->Decode(request.tokens); - { - // update generation result and trigger future - GenerationResult &gr = request_generation_results[request.guid]; - assert(gr.guid == request.guid); - gr.output_tokens = request.tokens; - gr.output_text = output; - } - log_req_mgr.print("Final output: %s", output.c_str()); - new_bc.request_completed[i] = true; - num_processed_requests++; - ProfileInfo profile_info = profiling_requests[request.guid]; - profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); - total_request_run_time += - profile_info.finish_time - profile_info.start_time; - profiling_requests[request.guid] = profile_info; - log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) " - "finish(%.1lf) latency(%.1lf)", - request.guid, - profile_info.decoding_steps, - profile_info.start_time, - profile_info.finish_time, - profile_info.finish_time - profile_info.start_time); - - // Write output to file if needed: - if (!output_filepath.empty()) { - std::ofstream outputFile(output_filepath); - if (outputFile.is_open()) { - outputFile << "end-to-end latency: " << std::fixed - << std::setprecision(3) << total_request_run_time - << std::endl; - outputFile << "num decoding steps: " << 
profile_info.decoding_steps - << std::endl; - outputFile << "token IDs: "; - for (int i = 0; i < request.tokens.size(); i++) { - outputFile << request.tokens[i]; - if (i < request.tokens.size() - 1) { - outputFile << ","; + request.status = Request::COMPLETED; + log_req_mgr.print("[Done] guid(%zu) with final length(%zu)", + request.guid, + request.tokens.size()); + std::string output = this->tokenizer_->Decode(request.tokens); + { + // update generation result and trigger future + GenerationResult &gr = request_generation_results[request.guid]; + assert(gr.guid == request.guid); + gr.output_tokens = request.tokens; + gr.output_text = output; + } + log_req_mgr.print("Final output: %s", output.c_str()); + + new_bc.request_completed[i] = true; + new_bc.request_running[i] = false; + num_processed_requests++; + + // Log profiling info + ProfileInfo profile_info = profiling_requests[request.guid]; + profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); + total_request_run_time += + profile_info.finish_time - profile_info.start_time; + profiling_requests[request.guid] = profile_info; + log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) " + "finish(%.1lf) latency(%.1lf)", + request.guid, + profile_info.decoding_steps, + profile_info.start_time, + profile_info.finish_time, + profile_info.finish_time - profile_info.start_time); + + // Write output to file if needed: + if (!output_filepath.empty()) { + std::ofstream outputFile(output_filepath); + if (outputFile.is_open()) { + outputFile << "end-to-end latency: " << std::fixed + << std::setprecision(3) + << profile_info.finish_time - profile_info.start_time + << std::endl; + outputFile << "num decoding steps: " << profile_info.decoding_steps + << std::endl; + outputFile << "token IDs: "; + for (int i = 0; i < request.tokens.size(); i++) { + outputFile << request.tokens[i]; + if (i < request.tokens.size() - 1) { + outputFile << ","; + } } + outputFile << std::endl; + outputFile << output; + outputFile.close(); + } else { + std::cout << "Unable to open the output file: " << output_filepath + << std::endl; + assert(false); } - outputFile << std::endl; - outputFile << output; - outputFile.close(); - } else { - std::cout << "Unable to open the output file: " << output_filepath - << std::endl; - assert(false); } - } - // delete the old input tree from cache - dfs_tree_inputs.erase(request.guid); + // delete the old input tree from cache + dfs_tree_inputs.erase(request.guid); - continue; - } + } else { // Request not finished, pass verified_tokens to next iteration - new_bc.request_completed[i] = false; - - // Normal Request Info - new_bc.requestsInfo[i].token_start_offset = verified_tokens.front().second; - new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; - new_bc.requestsInfo[i].max_sequence_length = - old_bc.requestsInfo[i].max_sequence_length; - new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size(); - - // TODO: Beam Request Info, missing from VerifyTreeBatchConfig - int new_max_depth = new_bc.requestsInfo[i].max_sequence_length - - new_bc.requestsInfo[i].token_start_offset - - verified_tokens.size(); - new_bc.beamRequestsInfo[i].current_depth = 1; - new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; - new_bc.beamRequestsInfo[i].max_depth = - std::min(new_max_depth, BeamSearchBatchConfig::MAX_BEAM_DEPTH); - for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { - new_bc.beamRequestsInfo[i].parent_id[j] = 0; - 
new_bc.beamRequestsInfo[i].probs[j] = 1; - } + new_bc.request_completed[i] = false; + new_bc.request_running[i] = true; + + // Normal Request Info + new_bc.requestsInfo[i].token_start_offset = + verified_tokens.front().second; + new_bc.requestsInfo[i].request_guid = + old_bc.requestsInfo[i].request_guid; + new_bc.requestsInfo[i].max_sequence_length = + old_bc.requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size(); - new_bc.sub_requests[i] = 1; + // TODO: Beam Request Info, missing from VerifyTreeBatchConfig + int new_max_depth = new_bc.requestsInfo[i].max_sequence_length - + new_bc.requestsInfo[i].token_start_offset - + verified_tokens.size(); + new_bc.beamRequestsInfo[i].current_depth = 1; + new_bc.beamRequestsInfo[i].beam_size = + BeamSearchBatchConfig::MAX_BEAM_WIDTH; + new_bc.beamRequestsInfo[i].max_depth = + std::min(new_max_depth, BeamSearchBatchConfig::MAX_BEAM_DEPTH); + for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { + new_bc.beamRequestsInfo[i].parent_id[j] = 0; + new_bc.beamRequestsInfo[i].probs[j] = 1; + } - // Token Info - for (int j = 0; j < verified_tokens.size(); j++) { - auto token = verified_tokens.at(j); + new_bc.sub_requests[i] = 1; - // Normal Token Info - new_bc.tokensInfo[new_bc.num_tokens].request_index = i; - new_bc.tokensInfo[new_bc.num_tokens].token_id = token.first; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = token.second; + // Token Info + for (int j = 0; j < verified_tokens.size(); j++) { + auto token = verified_tokens.at(j); - // Beam Token Info - new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; - new_bc.num_tokens++; + // Normal Token Info + new_bc.tokensInfo[new_bc.num_tokens].request_index = i; + new_bc.tokensInfo[new_bc.num_tokens].token_id = token.first; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + token.second; - // Add verified token to request's token list - request.tokens.push_back(token.first); + // Beam Token Info + new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; + new_bc.num_tokens++; - if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS) { - break; + // Add verified token to request's token list + request.tokens.push_back(token.first); + + if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS) { + break; + } + } + std::string output = this->tokenizer_->Decode(request.tokens); + log_req_mgr.print("Output: %s", output.c_str()); } + } else if (request.status == Request::PENDING) { + new_bc.request_completed[i] = false; + new_bc.request_running[i] = false; + + std::cout << "ssm_cache_size: " << request.ssm_cache_size << ", " + << "initial_len: " << request.initial_len << std::endl; + assert(request.ssm_cache_size == request.initial_len); + + // Normal Request Info + new_bc.requestsInfo[i].token_start_offset = request.ssm_cache_size; + new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; + new_bc.requestsInfo[i].max_sequence_length = + old_bc.requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[i].num_tokens_in_batch = 0; + + // TODO: Beam Request Info, missing from VerifyTreeBatchConfig + new_bc.beamRequestsInfo[i].current_depth = 1; + new_bc.beamRequestsInfo[i].beam_size = + BeamSearchBatchConfig::MAX_BEAM_WIDTH; + new_bc.beamRequestsInfo[i].max_depth = 0; + for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { + new_bc.beamRequestsInfo[i].parent_id[j] = 0; + new_bc.beamRequestsInfo[i].probs[j] = 1; + } + + new_bc.sub_requests[i] = 1; + + // Token Info + std::string output = 
this->tokenizer_->Decode(request.tokens); + log_req_mgr.print("Output: %s", output.c_str()); + } else { + assert(false); } - std::string output = this->tokenizer_->Decode(request.tokens); - log_req_mgr.print("Output: %s", output.c_str()); } // Step 2: Initialize new request - new_bc.max_init_length = 0; for (int i = 0; i < BeamSearchBatchConfig::MAX_NUM_REQUESTS; i++) { if (new_bc.request_completed[i]) { if (!pending_request_queue.empty() && new_bc.num_tokens < BeamSearchBatchConfig::MAX_NUM_TOKENS) { Request new_request = pending_request_queue.front(); pending_request_queue.pop(); - new_bc.max_init_length = - std::max(new_bc.max_init_length, new_request.initial_len); // all_requests[new_request.guid] = new_request; new_bc.requestsInfo[i].token_start_offset = 0; new_bc.requestsInfo[i].request_guid = new_request.guid; @@ -886,6 +780,33 @@ BeamSearchBatchConfig new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; } + + // if (new_bc.requestsInfo[i].num_tokens_in_batch < + // new_request.initial_len) { + // all_requests[new_request.guid].status = Request::PENDING; + // new_bc.request_running[i] = false; + // std::cout << "Request " << new_request.guid << " is pending" + // << std::endl; + // } else { + // all_requests[new_request.guid].status = Request::RUNNING; + // new_bc.request_running[i] = true; + // std::cout << "Request " << new_request.guid << " is running" + // << std::endl; + // } + all_requests[new_request.guid].status = Request::PENDING; + all_requests[new_request.guid].ssm_cache_size = + new_bc.requestsInfo[i].num_tokens_in_batch; + new_bc.request_running[i] = false; + std::cout << "SSM KV Cache Size init: " + << all_requests[new_request.guid].ssm_cache_size << std::endl; + std::cout << "LLM KV Cache Size init: " + << all_requests[new_request.guid].llm_cache_size << std::endl; + + std::cout << "load " << new_bc.requestsInfo[i].num_tokens_in_batch + << " tokens for request " << new_request.guid << std::endl; + std::cout << "total prompt in request: " << new_request.initial_len + << std::endl; + if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS) { break; } @@ -902,6 +823,209 @@ BeamSearchBatchConfig return new_bc; } +/***** Beam Search Phase *****/ +BeamSearchBatchConfigFuture RequestManager::prepare_next_batch_beam( + BeamSearchBatchConfigFuture const &old_bc, + BeamInferenceResultFuture const &result) { + Runtime *runtime = Runtime::get_runtime(); + Context ctx = Runtime::get_context(); + + RequestManager *rm = this; + TaskLauncher launcher(RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, + TaskArgument(&rm, sizeof(RequestManager *))); + launcher.add_future(old_bc); + launcher.add_future(result); + return runtime->execute_task(ctx, launcher); +} + +BeamSearchBatchConfig RequestManager::prepare_next_batch_beam_task( + Task const *task, + std::vector<PhysicalRegion> const &regions, + Context ctx, + Runtime *runtime) { + RequestManager *rm = *((RequestManager **)task->args); + BeamSearchBatchConfig const &bc = + Future(task->futures[0]).get_result<BeamSearchBatchConfig>(); + BeamInferenceResult const &result = + Future(task->futures[1]).get_result<BeamInferenceResult>(); + return rm->prepare_next_batch_beam(bc, result); +} + +// update beam search metadata +BeamSearchBatchConfig + RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, + BeamInferenceResult const &result) { + const std::lock_guard<std::mutex> lock(request_queue_mutex); + if (verbose) { + std::cout << "\n############### prepare_next_batch_beam ###############\n"; + } + if (verbose) { + std::cout << "print all results" + << "\n"; + for (int i = 0; i 
< 40; i++) { + std::cout << result.token_ids[i] << ", "; + } + std::cout << "Current Beam Depth: " + << old_bc.beamRequestsInfo[0].current_depth << "\n"; + } + // Step 1: Store result to the beam tree struct + store_beam_metadata(old_bc, result); + + // Step 2: preparing the next batch for existing requests + BeamSearchBatchConfig new_bc; + new_bc.model_id = old_bc.model_id; + // std::cout << "old_bc.model_id: " << old_bc.model_id << "\n"; + + for (int i = 0; i < BatchConfig::MAX_NUM_REQUESTS; i++) { + if (old_bc.request_completed[i]) { + continue; + } + // Comment out this assertion since num_tokens_in_batch can be + // zero when beam search has reached required sequence length + // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); + Request &request = all_requests[old_bc.requestsInfo[i].request_guid]; + int processed_tokens = old_bc.requestsInfo[i].token_start_offset + + old_bc.requestsInfo[i].num_tokens_in_batch; + + // assert(processed_tokens < request.tokens.size()); + log_req_mgr.debug() << "processed_tokens: " << processed_tokens << "\n"; + // if (processed_tokens > + // old_bc.beamRequestsInfo[i].max_depth + request.tokens.size() && + // request.status == Request::RUNNING + // // || ir.results[t] == 0 TODO: replace this with + // ) { + // // log_req_mgr.print("[Done] guid(%zu) with spec_tree_depth(%d)", + // // old_bc.requestsInfo[i].request_guid, + // // old_bc.beamRequestsInfo[i].max_depth); + // // // new_bc.request_completed[i] = true; + // // new_bc.request_completed[i] = false; + // // new_bc.requestsInfo[i].token_start_offset = processed_tokens; + // // new_bc.requestsInfo[i].request_guid = + // // old_bc.requestsInfo[i].request_guid; + // // new_bc.requestsInfo[i].max_sequence_length = + // // old_bc.requestsInfo[i].max_sequence_length; + // // new_bc.beamRequestsInfo[i].current_depth = + // // old_bc.beamRequestsInfo[i].current_depth; + // // new_bc.request_running[i] = false; + // std::cout << "beam search end:" << request.status << i << ", " + // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; + // } + // else + { + log_req_mgr.debug() << "num tokens: " << old_bc.num_tokens << ", " + << new_bc.num_tokens; + new_bc.request_completed[i] = false; + new_bc.requestsInfo[i].token_start_offset = processed_tokens; + new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; + new_bc.requestsInfo[i].max_sequence_length = + old_bc.requestsInfo[i].max_sequence_length; + + // update the beam search metadata + // how many sub request in current request + // why is sub_requests has MAX_NUM_REQUESTS * MAX_BEAM_WIDTH entries? + new_bc.sub_requests[i] = old_bc.beamRequestsInfo[i].beam_size; + + // update the parentid, accumalated_probs, depth, and token_ids + new_bc.beamRequestsInfo[i].beam_size = + old_bc.beamRequestsInfo[i].beam_size; + new_bc.beamRequestsInfo[i].max_depth = + old_bc.beamRequestsInfo[i].max_depth; + if (request.status == Request::RUNNING) { + new_bc.beamRequestsInfo[i].current_depth = + old_bc.beamRequestsInfo[i].current_depth + 1; + new_bc.request_running[i] = true; + // do the slot exchange to minimize the cache exchange in kernel. + update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id), i); + } else { + // if the request is pending, we need to update the beam search + // metadata based on the initial length + new_bc.beamRequestsInfo[i].current_depth = + old_bc.beamRequestsInfo[i].current_depth; + new_bc.request_running[i] = false; + } + + // do the slot exchange to minimize the cache exchange in kernel. 
+ // update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id), + // i); + if (new_bc.requestsInfo[i].token_start_offset >= request.tokens.size()) { + // Incremental phase + if (request.status == Request::RUNNING) { + new_bc.requestsInfo[i].num_tokens_in_batch = 1; + } else { + new_bc.requestsInfo[i].num_tokens_in_batch = 0; + } + + if (verbose) { + std::cout << "[ Beam Spec] " << request.guid << std::endl; + std::cout << "Incremental phase: " << request.tokens.size() + << ", num_tokens_in_batch: " + << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + } + } else { + // Prompt phase + new_bc.requestsInfo[i].num_tokens_in_batch = + // std::min(BatchConfig::MAX_NUM_TOKENS - new_bc.num_tokens, + std::min(BatchConfig::MAX_NUM_TOKENS - new_bc.num_tokens - + BatchConfig::MAX_NUM_REQUESTS + i, + (int)request.tokens.size() - + new_bc.requestsInfo[i].token_start_offset); + request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch; + if (verbose) { + std::cout << "[ Beam Spec] " << request.guid << std::endl; + std::cout << "Prompt phase: " << request.tokens.size() + << ", num_tokens_in_batch:" + << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + std::cout << "Update ssm cache size: " << request.ssm_cache_size + << std::endl; + } + } + + if (verbose) { + std::cout << "SSM KV Cache Size beam: " << request.ssm_cache_size + << std::endl; + std::cout << "LLM KV Cache Size beam: " << request.llm_cache_size + << std::endl; + } + + // register more tokens due to the beam width + for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { + int depth = new_bc.requestsInfo[i].token_start_offset + j; + for (int k = 0; k < new_bc.sub_requests[i]; k++) { + new_bc.tokensInfo[new_bc.num_tokens].request_index = i; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; + + // get value from requestinfo + if (request.status == Request::RUNNING) { + // std::cout << "[running ]Num of token in batch: " + // << new_bc.requestsInfo[i].num_tokens_in_batch + // << std::endl; + new_bc.tokensInfo[new_bc.num_tokens].token_id = + new_bc.beamRequestsInfo[i].tokens[k]; + } else { + // std::cout << "[pending ]Num of token in batch: " + // << new_bc.requestsInfo[i].num_tokens_in_batch + // << std::endl; + new_bc.tokensInfo[new_bc.num_tokens].token_id = + request.tokens[request.tokens.size() - 1]; + } + + new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; + new_bc.num_tokens++; + } + } + } + } + if (verbose) { + std::cout << "prepare_next_batch_beam OLD vs NEW batchconfigs:" + << std::endl; + old_bc.print(); + new_bc.print(); + } + return new_bc; +} + +/***** Verify Phase *****/ + TreeVerifyBatchConfigFuture RequestManager::prepare_next_batch_verify( std::vector const &old_batches) { Runtime *runtime = Runtime::get_runtime(); @@ -943,6 +1067,17 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens_to_commit = 0; new_bc.num_tokens = 0; + int max_prompt_load_size = BatchConfig::MAX_NUM_TOKENS; + for (int i = 0; i < TreeVerifyBatchConfig::MAX_NUM_REQUESTS; i++) { + if (old_batches.at(0).request_completed[i]) { + continue; + } else if (old_batches.at(0).request_running[i]) { + max_prompt_load_size -= (BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1); + } else { + max_prompt_load_size -= 1; + } + } + for (int i = 0; i < TreeVerifyBatchConfig::MAX_NUM_REQUESTS; i++) { if (old_batches.at(0).request_completed[i]) { continue; @@ -950,60 +1085,73 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( size_t guid = 
old_batches.at(0).requestsInfo[i].request_guid; Request &request = all_requests[guid]; - // Get the dfs tree - std::vector<std::vector<std::pair<BatchConfig::TokenId, int>>> - all_dfs_trees; - - for (int j = 0; j < old_batches.size(); j++) { - std::vector<std::pair<BatchConfig::TokenId, int>> new_tree = - traverse_beam_tree(old_batches.at(j), i, request.tokens.size() - 1); - all_dfs_trees.push_back(new_tree); - } - assert(all_dfs_trees.size() == old_batches.size()); - std::vector<std::pair<BatchConfig::TokenId, int>> dfs_tree_inputs = - merge_dfs_trees(all_dfs_trees, request.tokens.size() - 1, guid); + // Profiling + profiling_requests[request.guid].decoding_steps += 1; - if (verbose) { - std::cout << "Request Tokens Size: " << request.tokens.size() + if (request.status == Request::RUNNING) { + new_bc.request_running[i] = true; + std::cout << "[Verify] Request " << request.guid << " is running" << std::endl; - for (int k = 0; k < request.tokens.size(); k++) { - std::cout << k << ": " << request.tokens[k] << std::endl; - } - } - // Normal Request Info - new_bc.requestsInfo[i].token_start_offset = dfs_tree_inputs.front().second; - new_bc.requestsInfo[i].request_guid = - old_batches.at(0).requestsInfo[i].request_guid; - new_bc.requestsInfo[i].max_sequence_length = - old_batches.at(0).requestsInfo[i].max_sequence_length; - // TODO: Check this - new_bc.requestsInfo[i].num_tokens_in_batch = 0; - new_bc.request_completed[i] = false; + // Get the dfs tree + std::vector<std::vector<std::pair<BatchConfig::TokenId, int>>> + all_dfs_trees; - // Profiling - profiling_requests[new_bc.requestsInfo[i].request_guid].decoding_steps += 1; - // TODO: Add prompt token first in first verify iteration - if (request.tokens.size() == request.initial_len) { - // Initialization (prompt) phase - for (int j = 0; j < request.initial_len; j++) { - new_bc.tokensInfo[new_bc.num_tokens].request_index = i; - new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens[j]; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = j; + for (int j = 0; j < old_batches.size(); j++) { + std::vector<std::pair<BatchConfig::TokenId, int>> new_tree = + traverse_beam_tree(old_batches.at(j), i, request.tokens.size() - 1); + all_dfs_trees.push_back(new_tree); + } + assert(all_dfs_trees.size() == old_batches.size()); + std::vector<std::pair<BatchConfig::TokenId, int>> dfs_tree_inputs = + merge_dfs_trees(all_dfs_trees, request.tokens.size() - 1, guid); - new_bc.num_tokens++; - new_bc.requestsInfo[i].num_tokens_in_batch++; + if (verbose) { + std::cout << "Request Tokens Size: " << request.tokens.size() + << std::endl; + for (int k = 0; k < request.tokens.size(); k++) { + std::cout << k << ": " << request.tokens[k] << std::endl; + } } - std::cout << "new_bc.num_tokens: " << new_bc.num_tokens << std::endl; - if (new_bc.num_tokens >= BatchConfig::MAX_NUM_TOKENS) { - assert(false && - "Exceeding the space available in the TreeVerify batch"); - break; + // Normal Request Info + new_bc.requestsInfo[i].token_start_offset = + dfs_tree_inputs.front().second; + new_bc.requestsInfo[i].request_guid = + old_batches.at(0).requestsInfo[i].request_guid; + new_bc.requestsInfo[i].max_sequence_length = + old_batches.at(0).requestsInfo[i].max_sequence_length; + // TODO: Check this + new_bc.requestsInfo[i].num_tokens_in_batch = 0; + new_bc.request_completed[i] = false; + + // Committed Tokens + if (committed_tokens.find(guid) != committed_tokens.end()) { + for (int j = 0; j < dfs_tree_inputs.size(); j++) { + if (j < committed_tokens.at(guid).size()) { + auto committed_token = committed_tokens.at(guid).at(j); + new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = + committed_token.second; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = + i; + 
new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = + committed_token.first; + if (verbose) { + std::cout << new_bc.num_tokens_to_commit + << "- committed_token.token_depth: " + << committed_token.first + << ", token_index: " << committed_token.second + << std::endl; + } + new_bc.num_tokens_to_commit++; + } + } + } + if (verbose) { + std::cout << "new_bc.num_tokens_to_commit: " + << new_bc.num_tokens_to_commit << std::endl; } - new_bc.requestsInfo[i].token_start_offset = 0; - } else { // Incremental phase: only add the last committed token new_bc.tokensInfo[new_bc.num_tokens].request_index = i; new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens.back(); @@ -1013,116 +1161,124 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS) { + if (new_bc.num_tokens > BatchConfig::MAX_NUM_TOKENS) { assert(false && "Exceeding the space available in the TreeVerify batch"); break; } new_bc.requestsInfo[i].token_start_offset = request.tokens.size() - 1; - } - - if (verbose) { - std::cout << "dfs_tree_inputs.size(): " << dfs_tree_inputs.size() - << std::endl; - } - // add prompt to the dfs tree - if (committed_tokens.find(guid) != committed_tokens.end()) { - if (dfs_tree_inputs.at(0).second == - request.initial_len + committed_tokens.at(guid).size() - 1) { - for (int j = 0; j < request.initial_len; j++) { - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = j; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = - i; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = j; - if (verbose) { - std::cout << new_bc.num_tokens_to_commit - << "- committed_token.token_depth: " << j - << ", token_index: " << j << std::endl; - } - new_bc.num_tokens_to_commit++; - } - } else { - // only add the root token - auto committed_token = committed_tokens.at(guid).at(0); - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = - committed_token.second; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = i; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = - committed_token.first; + // Add Tokens from the DFS Tree to the next batch + for (int j = 1; j < dfs_tree_inputs.size(); j++) { + auto token = dfs_tree_inputs.at(j); if (verbose) { - std::cout << new_bc.num_tokens_to_commit - << "- committed_token.token_depth: " - << committed_token.first - << ", token_index: " << committed_token.second << std::endl; + std::cout << "[" << j << "] Token: " << token.first + << ", Depth:" << token.second << std::endl; } - new_bc.num_tokens_to_commit++; - } - if (verbose) { - std::cout << "new_bc.num_tokens_to_commit: " - << new_bc.num_tokens_to_commit << std::endl; - } - } + // Normal Token Info + new_bc.tokensInfo[new_bc.num_tokens].request_index = i; + new_bc.tokensInfo[new_bc.num_tokens].token_id = token.first; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + token.second; - // Token Info - for (int j = 1; j < dfs_tree_inputs.size(); j++) { - auto token = dfs_tree_inputs.at(j); - if (verbose) { - std::cout << "[" << j << "] Token: " << token.first - << ", Depth:" << token.second << std::endl; + new_bc.num_tokens++; + new_bc.requestsInfo[i].num_tokens_in_batch++; + + if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS - 1) { + break; + } } - // Normal Token Info - new_bc.tokensInfo[new_bc.num_tokens].request_index = i; - new_bc.tokensInfo[new_bc.num_tokens].token_id 
= token.first; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = token.second; - // TODO: Add committed token info + } else if (request.status == Request::PENDING) { + new_bc.request_running[i] = false; if (verbose) { - std::cout << "committed_tokens.size(): " << new_bc.num_tokens_to_commit + std::cout << "[Verify] Request " << request.guid + << " is pending in loading prompt phase" << std::endl; + std::cout << "SSM KV Cache Size verify: " << request.ssm_cache_size + << std::endl; + std::cout << "LLM KV Cache Size verify: " << request.llm_cache_size << std::endl; } + // Commit all tokens from the last loading batch if (committed_tokens.find(guid) != committed_tokens.end()) { - if (j < committed_tokens.at(guid).size()) { - auto committed_token = committed_tokens.at(guid).at(j); + for (int j = 0; j < committed_tokens.at(guid).size(); j++) { + auto token = committed_tokens.at(guid).at(j); new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = - committed_token.second; + token.second; new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = i; new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = - committed_token.first; - if (verbose) { - std::cout << new_bc.num_tokens_to_commit - << "- committed_token.token_depth: " - << committed_token.first - << ", token_index: " << committed_token.second - << std::endl; - } + token.first; + new_bc.num_tokens_to_commit++; + request.llm_cache_size++; } - } - if (verbose) { - std::cout << "new_bc.num_tokens_to_commit: " + std::cout << "[Verify] Committed Tokens from last loading batch: " << new_bc.num_tokens_to_commit << std::endl; } - new_bc.num_tokens++; - new_bc.requestsInfo[i].num_tokens_in_batch++; + // Normal Request Info + new_bc.requestsInfo[i].token_start_offset = request.llm_cache_size; + new_bc.requestsInfo[i].request_guid = + old_batches.at(0).requestsInfo[i].request_guid; + new_bc.requestsInfo[i].max_sequence_length = + old_batches.at(0).requestsInfo[i].max_sequence_length; - if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS - 1) { - break; - } - } + new_bc.request_completed[i] = false; - std::cout << "new_bc.num_tokens: " << new_bc.num_tokens << std::endl; - } + new_bc.requestsInfo[i].num_tokens_in_batch = std::min( + max_prompt_load_size, + (int)request.initial_len - new_bc.requestsInfo[i].token_start_offset); + max_prompt_load_size -= new_bc.requestsInfo[i].num_tokens_in_batch; - if (verbose) { - std::cout << "prepare_next_batch_verify OLD vs NEW batchconfigs below:" - << std::endl; - // old_batches.print(); - // new_bc.print(); + std::cout << "max_prompt_load_size: " << max_prompt_load_size + << std::endl; + std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " << i << ", " + << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + + if (request.llm_cache_size < request.initial_len) { + // Initialization (prompt) phase + for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { + new_bc.tokensInfo[new_bc.num_tokens].request_index = i; + new_bc.tokensInfo[new_bc.num_tokens].token_id = + request.tokens[request.llm_cache_size + j]; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + request.llm_cache_size + j; + + new_bc.num_tokens++; + } + + if (new_bc.num_tokens > BatchConfig::MAX_NUM_TOKENS) { + assert(false && + "Exceeding the space available in the TreeVerify batch"); + break; + } + } else { // launch the request into running phase after loading all prompt + if (BatchConfig::MAX_NUM_TOKENS - new_bc.num_tokens > 0) { + request.status = Request::RUNNING; 
+ new_bc.request_running[i] = true; + + new_bc.tokensInfo[new_bc.num_tokens].request_index = i; + new_bc.tokensInfo[new_bc.num_tokens].token_id = request.tokens.back(); + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + request.tokens.size() - 1; + + new_bc.num_tokens++; + new_bc.requestsInfo[i].num_tokens_in_batch++; + std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " + << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + + dfs_tree_inputs[guid] = + std::vector<std::pair<BatchConfig::TokenId, int>>{std::make_pair( + request.tokens.back(), request.tokens.size() - 1)}; + } + } + + } else { + assert(false && "Request status is not RUNNING or PENDING"); + } } return new_bc; @@ -1145,14 +1301,16 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } for (int i = 0; i <= old_bc.num_tokens; i++) { - int request_index = old_bc.tokensInfo[i].request_index; - - // End of the request if (i == old_bc.num_tokens || - old_bc.requestsInfo[request_index].request_guid != guid) { + old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid != + guid) { + + int index = old_bc.tokensInfo[i - 1].request_index; + int beam_size = old_bc.beamRequestsInfo[index].beam_size; + int depth = old_bc.beamRequestsInfo[index].current_depth; // Each token yields (beam_width) results - int beam_width = old_bc.beamRequestsInfo[request_index].beam_size; + int beam_width = old_bc.beamRequestsInfo[index].beam_size; // Count tokens sent to model in this request to find the final token's // index @@ -1165,10 +1323,6 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, << ", value: " << result.token_ids[result_index] << "\n"; } - int index = old_bc.tokensInfo[i - 1].request_index; - int beam_size = old_bc.beamRequestsInfo[index].beam_size; - int depth = old_bc.beamRequestsInfo[index].current_depth; - Request &request = all_requests[old_bc.requestsInfo[index].request_guid]; if (depth == 1) { @@ -1212,7 +1366,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // update the guid and start_depth for current request if (i < old_bc.num_tokens) { - guid = old_bc.requestsInfo[request_index].request_guid; + guid = old_bc.requestsInfo[index].request_guid; start_depth = old_bc.tokensInfo[i].abs_depth_in_request; } } @@ -1585,24 +1739,27 @@ std::vector<std::pair<BatchConfig::TokenId, int>> return merged_tree; } -GenerationResult FFModel::generate(std::string const &text, +GenerationResult FFModel::generate(std::vector<std::string> &prompts, int max_seq_length) { RequestManager *rm = RequestManager::get_request_manager(); if (rm->get_num_ssms() == 0) { // No SSMs: perform incremental decoding - return rm->generate_incr_decoding(this, text, max_seq_length); + return rm->generate_incr_decoding(this, prompts, max_seq_length); } else { // Registered SSMs: perform speculative inference - return rm->generate_spec_infer(this, text, max_seq_length); + return rm->generate_spec_infer(this, prompts, max_seq_length); } } /*static*/ -GenerationResult RequestManager::generate_incr_decoding(FFModel *llm, - std::string const &text, - int max_seq_length) { +GenerationResult RequestManager::generate_incr_decoding( + FFModel *llm, std::vector<std::string> &prompts, int max_seq_length) { InferenceManager *im = InferenceManager::get_inference_manager(); - RequestGuid guid = register_new_request(text, max_seq_length); + RequestGuid guid; + for (int i = 0; i < prompts.size(); i++) { + guid = register_new_request(prompts.at(i), max_seq_length); + } + if (guid == 0) { std::cout << "=========== Discard request exceed prompt maximum... 
===========" @@ -1652,11 +1809,13 @@ GenerationResult RequestManager::generate_incr_decoding(FFModel *llm, } /*static*/ -GenerationResult RequestManager::generate_spec_infer(FFModel *llm, - std::string const &text, - int max_seq_length) { +GenerationResult RequestManager::generate_spec_infer( + FFModel *llm, std::vector<std::string> &prompts, int max_seq_length) { InferenceManager *im = InferenceManager::get_inference_manager(); - RequestGuid guid = register_new_request(text, max_seq_length); + RequestGuid guid; + for (int i = 0; i < prompts.size(); i++) { + guid = register_new_request(prompts.at(i), max_seq_length); + } if (guid == 0) { std::cout << "=========== Discard request exceed prompt maximum... ==========="
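
Usage note (not part of the diff): the sketch below illustrates how a caller is expected to drive the batched generate() API introduced above, mirroring the updated incr_decoding.cc. It is a minimal example under assumptions; the helper function name run_batched_generation and the nlohmann::json include path are illustrative and not defined by this patch.

#include "flexflow/model.h"
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

// Collect every prompt first, then issue a single generate() call; after this
// change the RequestManager interleaves up to BatchConfig::MAX_NUM_REQUESTS (4)
// active requests per decoding step instead of serving one prompt at a time.
void run_batched_generation(FlexFlow::FFModel &model,
                            nlohmann::json const &prompt_json) {
  std::vector<std::string> prompts;
  for (auto &prompt : prompt_json) {
    prompts.push_back(prompt.get<std::string>());
  }
  // One call registers all requests and returns once decoding has finished.
  FlexFlow::GenerationResult result =
      model.generate(prompts, 128 /*max_sequence_length*/);
}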