From c8d442d10fa6b5aa78ac30b4e49a2cb3f05f5a23 Mon Sep 17 00:00:00 2001
From: fruitea
Date: Wed, 16 Oct 2024 22:35:50 +0000
Subject: [PATCH] feat: add two naive scheduling policies

---
 include/flexflow/request_manager.h |  9 +++
 inference/spec_infer/spec_infer.cc | 16 +++++
 src/runtime/request_manager.cc     | 94 ++++++++++++++++++++++++++++++
 3 files changed, 119 insertions(+)

diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index ebfd31b3d3..c151cdfbc4 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -297,7 +297,11 @@ class RequestManager {
   void set_slo_violation_early_termination(
       bool slo_violation_early_termination);
   void set_spec_infer_old_version(bool spec_infer_old_version);
+  void set_greedy_schedule(bool greedy_schedule);
+  void set_equal_schedule(bool equal_schedule);
   bool get_spec_infer_old_version();
+  bool get_greedy_schedule();
+  bool get_equal_schedule();
   double get_request_expected_latency(Request &request);
   Request &get_request_with_guid(RequestGuid guid);
   int register_ssm_model(FFModel *model);
@@ -403,6 +407,8 @@ class RequestManager {
   bool memory_occupancy = false;
   bool slo_violation_early_termination = false;
   bool spec_infer_old_version = false;
+  bool greedy_schedule = false;
+  bool equal_schedule = false;
 
   std::unique_ptr<Tokenizer> tokenizer_;
   bool verbose;
@@ -526,9 +532,12 @@
   void add_tokens_to_spec_token_tree_old_version(
       InferenceResult const &ssm_inference_result);
   void prune_token_tree();
+  void prune_token_tree_equal();
+  void prune_token_tree_greedy();
   void add_tokens_toward_slo(RequestGuid guid, int &budget);
   void add_tokens_toward_memory_occupancy(int budget);
   void add_tokens_toward_goodput(int budget);
+  void add_tokens_toward_goodput_per_request(int budget, int request_index);
   void update_token_tree_depth();
 
   /* ---------- Spec Decoding Helper Functions ---------- */
diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc
index e3fe4a2508..7b6bad7c4d 100644
--- a/inference/spec_infer/spec_infer.cc
+++ b/inference/spec_infer/spec_infer.cc
@@ -81,6 +81,8 @@ void parse_input_args(char **argv,
                       int &llm_verify_latency_ms,
                       double &request_per_second,
                       bool &spec_infer_old_version,
+                      bool &greedy_schedule,
+                      bool &equal_schedule,
                       std::string &emission_file_path) {
   for (int i = 1; i < argc; i++) {
     // llm model name
@@ -206,6 +208,14 @@ void parse_input_args(char **argv,
     if (!strcmp(argv[i], "--spec-infer-old-version")) {
       spec_infer_old_version = true;
       continue;
     }
+    if (!strcmp(argv[i], "--greedy-schedule")) {
+      greedy_schedule = true;
+      continue;
+    }
+    if (!strcmp(argv[i], "--equal-schedule")) {
+      equal_schedule = true;
+      continue;
+    }
     if (!strcmp(argv[i], "--emission-file-path")) {
       emission_file_path = std::string(argv[++i]);
       continue;
     }
@@ -383,6 +393,8 @@ void FlexFlow::top_level_task(Task const *task,
   int llm_verify_latency_ms = 50;
   double request_per_second = 1.0;
   bool spec_infer_old_version = false;
+  bool greedy_schedule = false;
+  bool equal_schedule = false;
   std::string emission_file_path;

   InputArgs const &command_args = HighLevelRuntime::get_input_args();
@@ -413,6 +425,8 @@ void FlexFlow::top_level_task(Task const *task,
                    llm_verify_latency_ms,
                    request_per_second,
                    spec_infer_old_version,
+                   greedy_schedule,
+                   equal_schedule,
                    emission_file_path);
   if (max_tokens_per_ssm_batch == -1) {
     max_tokens_per_ssm_batch = max_tokens_per_batch;
   }
@@ -460,6 +474,8 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_ssm_spec_latency(ssm_spec_latency_ms);
   rm->set_llm_verify_latency(llm_verify_latency_ms);
   rm->set_spec_infer_old_version(spec_infer_old_version);
+  rm->set_greedy_schedule(greedy_schedule);
+  rm->set_equal_schedule(equal_schedule);
   rm->register_output_filepath(file_paths.output_file_path);

   // Create LLM model
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 1e3335d285..0b3b8aaa2c 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -320,10 +320,26 @@ void RequestManager::set_spec_infer_old_version(bool spec_infer_old_version_) {
   spec_infer_old_version = spec_infer_old_version_;
 }

+void RequestManager::set_greedy_schedule(bool greedy_schedule_) {
+  greedy_schedule = greedy_schedule_;
+}
+
+void RequestManager::set_equal_schedule(bool equal_schedule_) {
+  equal_schedule = equal_schedule_;
+}
+
 bool RequestManager::get_spec_infer_old_version() {
   return spec_infer_old_version;
 }

+bool RequestManager::get_greedy_schedule() {
+  return greedy_schedule;
+}
+
+bool RequestManager::get_equal_schedule() {
+  return equal_schedule;
+}
+
 double RequestManager::get_request_expected_latency(Request &request) {
   return request.get_slo_ratio() * baseline_latency_ms *
          request.decode_length();
@@ -2888,6 +2904,12 @@
 }

 void RequestManager::prune_token_tree() {
+  if (get_greedy_schedule()) {
+    return prune_token_tree_greedy();
+  } else if (get_equal_schedule()) {
+    return prune_token_tree_equal();
+  }
+
   // Each reqeust has at least one token
   int budget = get_max_tokens_per_batch() - num_available_requests;
   assert(budget >= 0);
@@ -2933,6 +2955,48 @@
   }
 }

+void RequestManager::prune_token_tree_equal() {
+  // Split the token budget evenly across the active requests
+  int const equal_budget =
+      get_max_tokens_per_batch() / get_num_active_requests();
+  assert(equal_budget >= 0);
+
+  for (int request_index = 0; request_index < get_max_requests_per_batch();
+       ++request_index) {
+    if (!request_available[request_index]) {
+      continue;
+    }
+    RequestGuid guid = guid_of_requests[request_index];
+    Request &request = all_requests[guid];
+    assert(request.status == Request::RUNNING);
+    int budget = equal_budget;
+    assert(budget >= 0);
+    if (budget > 0) {
+      add_tokens_toward_goodput_per_request(budget, request_index);
+    }
+  }
+}
+
+void RequestManager::prune_token_tree_greedy() {
+  // One global budget, spent greedily across all running requests
+  int budget = get_max_tokens_per_batch();
+  assert(budget >= 0);
+
+  for (int request_index = 0; request_index < get_max_requests_per_batch();
+       ++request_index) {
+    if (!request_available[request_index]) {
+      continue;
+    }
+    RequestGuid guid = guid_of_requests[request_index];
+    Request &request = all_requests[guid];
+    assert(request.status == Request::RUNNING); // sanity check only
+  }
+
+  if (budget > 0) {
+    add_tokens_toward_goodput(budget);
+  }
+}
+
 void RequestManager::add_tokens_toward_slo(RequestGuid guid, int &budget) {
   Request &request = all_requests[guid];
   double num_tokens_to_decode = (ssm_spec_latency_ms + llm_verify_latency_ms) *
@@ -3108,6 +3172,36 @@ void RequestManager::add_tokens_toward_goodput(int budget) {
   }
 }

+void RequestManager::add_tokens_toward_goodput_per_request(int budget,
+                                                           int request_index) {
+  RequestGuid guid = guid_of_requests[request_index];
+  Request &request = all_requests[guid];
+  assert(request.status == Request::RUNNING);
+  if (request.token_tree_nodes_acc_prob_pair_pq.empty()) {
+    return;
+  }
+
+  auto &pq = request.token_tree_nodes_acc_prob_pair_pq;
+
+  // Dequeue the best nodes until the per-request budget is used up
+  while (budget > 0 && !pq.empty()) {
+    auto [node_ptr, acc_log_prob] = pq.top();
+    pq.pop();
+    node_ptr->included = true;
+    budget--;
+  }
+
+  // Reset this request's priority queue, reusing a preallocated buffer
+  std::vector<std::pair<std::shared_ptr<TokenTreeNode>, double>>
+      _prealloc_vector;
+  _prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM);
+  request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue<
+      std::pair<std::shared_ptr<TokenTreeNode>, double>,
+      std::vector<std::pair<std::shared_ptr<TokenTreeNode>, double>>,
+      SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(),
+                                        std::move(_prealloc_vector));
+}
+
 std::ostream &operator<<(std::ostream &os, TokenTree const &token_tree) {
   os << "Token tree: " << std::endl;
   int layer_idx = 0;
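
The two policies differ only in how the per-batch token budget is allocated: --equal-schedule gives every running request an identical share of max_tokens_per_batch and spends it through add_tokens_toward_goodput_per_request, while --greedy-schedule hands the whole budget to add_tokens_toward_goodput, which always takes the globally best candidate and can let one request consume the entire batch. The sketch below is a minimal standalone illustration of that difference, not FlexFlow code: the Request type and the prune_equal/prune_greedy names are hypothetical stand-ins, and a plain max-heap of scores stands in for the token tree's accumulated-probability queue.

#include <iostream>
#include <queue>
#include <vector>

// Hypothetical stand-in for a request's speculation state: a max-heap of
// accumulated-probability scores for its candidate tree nodes.
struct Request {
  std::priority_queue<double> candidates;
};

// Equal policy (cf. prune_token_tree_equal): every request gets the same
// share of the batch budget and spends it on its own best candidates.
static void prune_equal(std::vector<Request> &reqs, int max_tokens_per_batch) {
  int const share = max_tokens_per_batch / static_cast<int>(reqs.size());
  for (Request &r : reqs) {
    for (int budget = share; budget > 0 && !r.candidates.empty(); --budget) {
      r.candidates.pop(); // "include" this request's best remaining node
    }
  }
}

// Greedy policy (cf. prune_token_tree_greedy + add_tokens_toward_goodput):
// one global budget; each step takes the single best candidate across all
// requests, so a strong request can starve the others.
static void prune_greedy(std::vector<Request> &reqs, int max_tokens_per_batch) {
  for (int budget = max_tokens_per_batch; budget > 0; --budget) {
    Request *best = nullptr;
    for (Request &r : reqs) {
      if (!r.candidates.empty() &&
          (best == nullptr || r.candidates.top() > best->candidates.top())) {
        best = &r;
      }
    }
    if (best == nullptr) {
      break; // every candidate pool is exhausted
    }
    best->candidates.pop(); // "include" the globally best node
  }
}

int main() {
  // One request with strong candidates, one with weak candidates.
  std::vector<Request> reqs(2);
  for (double p : {0.9, 0.8, 0.7, 0.6}) {
    reqs[0].candidates.push(p);
  }
  for (double p : {0.3, 0.2}) {
    reqs[1].candidates.push(p);
  }

  // With a budget of 4, greedy spends every token on the first request,
  // printing "0 2" (candidates left per request); prune_equal on the same
  // input would leave "2 0".
  prune_greedy(reqs, 4);
  std::cout << reqs[0].candidates.size() << " " << reqs[1].candidates.size()
            << std::endl;
  return 0;
}

Neither policy redistributes leftover budget (equal wastes a request's unused share once its pool runs dry; greedy simply stops when every pool is empty), and neither weighs per-request latency targets the way the default prune_token_tree path does through helpers such as add_tokens_toward_slo, which is what makes them naive baselines. Both are opt-in via the new spec_infer flags and leave the default scheduling path untouched.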