Commit c8d442d

feat: add two naive scheduling policies

chenzhuofu committed Oct 16, 2024
1 parent 9cf66c1 commit c8d442d
Showing 3 changed files with 119 additions and 0 deletions.
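In brief, this commit adds two opt-in token-tree pruning policies to the speculative-decoding RequestManager: --equal-schedule divides the per-batch token budget evenly across active requests and fills each request's tree from its own candidate queue, while --greedy-schedule hands the entire budget to a single add_tokens_toward_goodput pass that draws from all requests as one pool. Neither flag is set by default, and greedy takes precedence if both are given.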
9 changes: 9 additions & 0 deletions include/flexflow/request_manager.h
@@ -297,7 +297,11 @@ class RequestManager {
void
set_slo_violation_early_termination(bool slo_violation_early_termination);
void set_spec_infer_old_version(bool spec_infer_old_version);
void set_greedy_schedule(bool greedy_schedule);
void set_equal_schedule(bool equal_schedule);
bool get_spec_infer_old_version();
bool get_greedy_schedule();
bool get_equal_schedule();
double get_request_expected_latency(Request &request);
Request &get_request_with_guid(RequestGuid guid);
int register_ssm_model(FFModel *model);
@@ -403,6 +407,8 @@ class RequestManager {
bool memory_occupancy = false;
bool slo_violation_early_termination = false;
bool spec_infer_old_version = false;
bool greedy_schedule = false;
bool equal_schedule = false;

std::unique_ptr<Tokenizer> tokenizer_;
bool verbose;
@@ -526,9 +532,12 @@ class RequestManager {
void add_tokens_to_spec_token_tree_old_version(
InferenceResult const &ssm_inference_result);
void prune_token_tree();
void prune_token_tree_equal();
void prune_token_tree_greedy();
void add_tokens_toward_slo(RequestGuid guid, int &budget);
void add_tokens_toward_memory_occupancy(int budget);
void add_tokens_toward_goodput(int budget);
void add_tokens_toward_goodput_per_request(int budget, int request_index);
void update_token_tree_depth();

/* ---------- Spec Decoding Helper Functions ---------- */
16 changes: 16 additions & 0 deletions inference/spec_infer/spec_infer.cc
@@ -81,6 +81,8 @@ void parse_input_args(char **argv,
int &llm_verify_latency_ms,
double &request_per_second,
bool &spec_infer_old_version,
bool &greedy_schedule,
bool &equal_schedule,
std::string &emission_file_path) {
for (int i = 1; i < argc; i++) {
// llm model name
@@ -206,6 +208,14 @@
spec_infer_old_version = true;
continue;
}
if (!strcmp(argv[i], "--greedy-schedule")) {
greedy_schedule = true;
continue;
}
if (!strcmp(argv[i], "--equal-schedule")) {
equal_schedule = true;
continue;
}
if (!strcmp(argv[i], "--emission-file-path")) {
emission_file_path = std::string(argv[++i]);
continue;
@@ -383,6 +393,8 @@ void FlexFlow::top_level_task(Task const *task,
int llm_verify_latency_ms = 50;
double request_per_second = 1.0;
bool spec_infer_old_version = false;
bool greedy_schedule = false;
bool equal_schedule = false;
std::string emission_file_path;

InputArgs const &command_args = HighLevelRuntime::get_input_args();
@@ -413,6 +425,8 @@ void FlexFlow::top_level_task(Task const *task,
llm_verify_latency_ms,
request_per_second,
spec_infer_old_version,
greedy_schedule,
equal_schedule,
emission_file_path);
if (max_tokens_per_ssm_batch == -1) {
max_tokens_per_ssm_batch = max_tokens_per_batch;
@@ -460,6 +474,8 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_ssm_spec_latency(ssm_spec_latency_ms);
rm->set_llm_verify_latency(llm_verify_latency_ms);
rm->set_spec_infer_old_version(spec_infer_old_version);
rm->set_greedy_schedule(greedy_schedule);
rm->set_equal_schedule(equal_schedule);
rm->register_output_filepath(file_paths.output_file_path);

// Create LLM model
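Both policies are off by default and enabled via the flags parsed above; if both are passed, greedy wins (see prune_token_tree below). A hypothetical launch line, with every pre-existing flag elided since this commit leaves the rest of the CLI unchanged:

./spec_infer <existing launch flags...> --equal-schedule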
94 changes: 94 additions & 0 deletions src/runtime/request_manager.cc
@@ -320,10 +320,26 @@ void RequestManager::set_spec_infer_old_version(bool spec_infer_old_version_) {
spec_infer_old_version = spec_infer_old_version_;
}

void RequestManager::set_greedy_schedule(bool greedy_schedule_) {
  greedy_schedule = greedy_schedule_;
}

void RequestManager::set_equal_schedule(bool equal_schedule_) {
equal_schedule = equal_schedule_;
}

bool RequestManager::get_spec_infer_old_version() {
return spec_infer_old_version;
}

bool RequestManager::get_greedy_schedule() {
  return greedy_schedule;
}

bool RequestManager::get_equal_schedule() {
return equal_schedule;
}

double RequestManager::get_request_expected_latency(Request &request) {
return request.get_slo_ratio() * baseline_latency_ms *
request.decode_length();
@@ -2888,6 +2904,12 @@ void RequestManager::add_tokens_to_spec_token_tree_old_version(
}

void RequestManager::prune_token_tree() {
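  // Naive-policy dispatch added in this commit: greedy is checked first, so
  // it takes precedence if both flags are set; with neither flag, control
  // falls through to the SLO-aware pruning below.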
if (get_greedy_schedule()) {
return prune_token_tree_greedy();
} else if (get_equal_schedule()) {
return prune_token_tree_equal();
}

// Each request has at least one token
int budget = get_max_tokens_per_batch() - num_available_requests;
assert(budget >= 0);
@@ -2933,6 +2955,48 @@
}
}

void RequestManager::prune_token_tree_equal() {
// Split the batch token budget evenly across active requests
int const equal_budget =
get_max_tokens_per_batch() / get_num_active_requests();
assert(equal_budget >= 0);

for (int request_index = 0; request_index < get_max_requests_per_batch();
++request_index) {
if (!request_available[request_index]) {
continue;
}
RequestGuid guid = guid_of_requests[request_index];
Request &request = all_requests[guid];
assert(request.status == Request::RUNNING);
int budget = equal_budget;
assert(budget >= 0);
if (budget > 0) {
add_tokens_toward_goodput_per_request(budget, request_index);
}
}
}
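Note that the equal split uses integer division, so up to one token fewer than the number of active requests can go unallocated. As a worked example with assumed values (not from this diff): with max_tokens_per_batch = 128 and 3 active requests, equal_budget = floor(128 / 3) = 42, leaving 128 - 3 * 42 = 2 tokens of the batch budget unused.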

void RequestManager::prune_token_tree_greedy() {
// Hand the entire batch token budget to a single global pass
int budget = get_max_tokens_per_batch();
assert(budget >= 0);

  for (int request_index = 0; request_index < get_max_requests_per_batch();
       ++request_index) {
    if (!request_available[request_index]) {
      continue;
    }
    RequestGuid guid = guid_of_requests[request_index];
    Request &request = all_requests[guid];
    assert(request.status == Request::RUNNING);
  }

zikun-li (Collaborator) commented on Oct 16, 2024: This part can be removed.

if (budget > 0) {
add_tokens_toward_goodput(budget);
}
}
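By contrast with the equal policy, greedy hands the whole batch budget to add_tokens_toward_goodput, which draws candidate tokens from all requests as a single pool, so one request with uniformly high-probability tokens can consume most of the budget. The per-request loop above only re-checks invariants and has no side effects, which is what the reviewer's note refers to.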

void RequestManager::add_tokens_toward_slo(RequestGuid guid, int &budget) {
Request &request = all_requests[guid];
double num_tokens_to_decode = (ssm_spec_latency_ms + llm_verify_latency_ms) *
@@ -3108,6 +3172,36 @@ void RequestManager::add_tokens_toward_goodput(int budget) {
}
}

void RequestManager::add_tokens_toward_goodput_per_request(int budget,
int request_index) {
RequestGuid guid = guid_of_requests[request_index];
Request &request = all_requests[guid];
assert(request.status == Request::RUNNING);
  if (request.token_tree_nodes_acc_prob_pair_pq.empty()) {
    return;
  }

auto &pq = request.token_tree_nodes_acc_prob_pair_pq;

// Perform dequeue and enqueue until the budget is used up
  while (budget > 0 && !pq.empty()) {
auto [node_ptr, acc_log_prob] = pq.top();
pq.pop();
node_ptr->included = true;
budget--;
}

// Reset this request's priority queue by assigning a fresh one backed by a preallocated vector
std::vector<std::pair<std::shared_ptr<TokenTreeNode>, double>>
_prealloc_vector;
_prealloc_vector.reserve(BatchConfig::MAX_SPEC_TREE_TOKEN_NUM);
request.token_tree_nodes_acc_prob_pair_pq = std::priority_queue<
std::pair<std::shared_ptr<TokenTreeNode>, double>,
std::vector<std::pair<std::shared_ptr<TokenTreeNode>, double>>,
SharedTokenTreeNodePtrDoubleLess>(SharedTokenTreeNodePtrDoubleLess(),
std::move(_prealloc_vector));
}
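A side note on the reset above: std::priority_queue exposes no clear(), so the member queue is replaced wholesale by a newly constructed queue whose backing vector has reserve()d capacity, keeping the storage warm for the next round of insertions. A minimal standalone sketch of the same idiom, with illustrative names not taken from this codebase:

#include <functional>
#include <queue>
#include <vector>

int main() {
  // Reserve storage up front, then hand the vector to an empty queue.
  std::vector<int> backing;
  backing.reserve(64); // hypothetical capacity bound
  std::priority_queue<int, std::vector<int>, std::less<int>> pq(
      std::less<int>(), std::move(backing));
  pq.push(3); // pushes reuse the reserved buffer, up to 64 elements
  return 0;
}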

std::ostream &operator<<(std::ostream &os, TokenTree const &token_tree) {
os << "Token tree: " << std::endl;
int layer_idx = 0;