Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specscheduler new scheduler #1492

Merged
merged 13 commits into from
Sep 7, 2024
181 changes: 112 additions & 69 deletions include/flexflow/request_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,70 @@ class InferenceManager {
std::unordered_map<FFModel *, FileDataLoader *> model_weights_loaders;
};

class TokenTreeNode {
public:
BatchConfig::TokenId id;
float log_accumulated_prob;
int parent_pos;
bool included = false;
bool gumbel = false;
float gumbel_logit = 0.0f;

TokenTreeNode(BatchConfig::TokenId id,
float log_accumulated_prob,
int parent_pos,
bool gumbel = false,
float gumbel_logit = 0.0f)
: id(id), log_accumulated_prob(log_accumulated_prob),
parent_pos(parent_pos), gumbel(gumbel), gumbel_logit(gumbel_logit) {}
};

// Ordering operators for shared TokenTreeNode pointers. Only the declarations
// appear in this section; no definition is visible here.
// NOTE(review): presumably these compare the nodes' scores in the same way as
// the comparator structs in this header — confirm against the definitions.
bool operator<(std::shared_ptr<TokenTreeNode> const &lhs,
std::shared_ptr<TokenTreeNode> const &rhs);

bool operator<=(std::shared_ptr<TokenTreeNode> const &lhs,
std::shared_ptr<TokenTreeNode> const &rhs);

// Comparator over std::shared_ptr<TokenTreeNode> that reports whether lhs
// scores strictly lower than rhs; use it to sort token tree nodes in
// ascending order.
struct SharedTokenTreeNodePtrLess {
  bool operator()(std::shared_ptr<TokenTreeNode> const &lhs,
                  std::shared_ptr<TokenTreeNode> const &rhs) const {
    // Plain nodes are ranked by their accumulated log-probability.
    if (!lhs->gumbel) {
      return lhs->log_accumulated_prob < rhs->log_accumulated_prob;
    }
    // A gumbel node is only expected to be compared with another gumbel node.
    assert(rhs->gumbel);
    return lhs->gumbel_logit < rhs->gumbel_logit;
  }
};

// Comparator over std::shared_ptr<TokenTreeNode> that reports whether lhs
// scores strictly higher than rhs; use it to sort token tree nodes in
// descending order.
struct SharedTokenTreeNodePtrGreater {
  bool operator()(std::shared_ptr<TokenTreeNode> const &lhs,
                  std::shared_ptr<TokenTreeNode> const &rhs) const {
    // Plain nodes are ranked by their accumulated log-probability.
    if (!lhs->gumbel) {
      return lhs->log_accumulated_prob > rhs->log_accumulated_prob;
    }
    // A gumbel node is only expected to be compared with another gumbel node.
    assert(rhs->gumbel);
    return lhs->gumbel_logit > rhs->gumbel_logit;
  }
};

class TokenTree {
public:
std::list<std::list<shared_ptr<TokenTreeNode>>> tree_layers = {};
void add_layer() {
tree_layers.emplace_back();
}

void clear() {
tree_layers.clear();
}
};

// Stream-insertion operator for TokenTree; takes the tree by const reference
// and returns the stream. Only the declaration appears in this section.
std::ostream &operator<<(std::ostream &os, TokenTree const &token_tree);

struct Request {
enum Status {
PENDING = 101, // loading prompt
Expand All @@ -68,6 +132,8 @@ struct Request {
int batch_index = -1;
int ssm_cache_size = 0;
int llm_cache_size = 0;
double slo_ratio = 1.0;
double decode_latency_ms = 0.0;
int ssm_prefill_len = 0;
int llm_prefill_len = 0;

Expand Down Expand Up @@ -136,52 +202,37 @@ struct Request {
// 1. Prefilling phase
// 2. Committing phase after the target model verification
StreamingCacheInfo streaming_cache_info;
};

class TokenTreeNode {
public:
BatchConfig::TokenId id;
float log_accumulated_prob;
int parent_pos;
bool pruned = false;
bool gumbel = false;
float gumbel_logit = 0.0f;
std::priority_queue<std::shared_ptr<TokenTreeNode>,
std::vector<std::shared_ptr<TokenTreeNode>>,
SharedTokenTreeNodePtrLess>
token_tree_nodes_pq;

TokenTreeNode(BatchConfig::TokenId id,
float log_accumulated_prob,
int parent_pos,
bool gumbel = false,
float gumbel_logit = 0.0f)
: id(id), log_accumulated_prob(log_accumulated_prob),
parent_pos(parent_pos), gumbel(gumbel), gumbel_logit(gumbel_logit) {}
double get_length_weight();
void set_slo_ratio(double slo_ratio_);
double get_slo_ratio();
};

// Ordering operators for shared TokenTreeNode pointers. Only the declarations
// appear in this section; no definition is visible here.
// NOTE(review): presumably these compare the nodes' scores in the same way as
// the comparator structs in this header — confirm against the definitions.
bool operator<(std::shared_ptr<TokenTreeNode> const &lhs,
std::shared_ptr<TokenTreeNode> const &rhs);

bool operator<=(std::shared_ptr<TokenTreeNode> const &lhs,
std::shared_ptr<TokenTreeNode> const &rhs);

// A comparator for std::shared_ptr<TokenTreeNode>
// This is used to sort the token tree nodes in descending order
struct CompareSharedTokenTreeNodePtr {
bool operator()(std::shared_ptr<TokenTreeNode> const &lhs,
std::shared_ptr<TokenTreeNode> const &rhs) const {
if (lhs->gumbel) {
assert(rhs->gumbel);
return lhs->gumbel_logit < rhs->gumbel_logit;
// A comparator for std::pair<std::shared_ptr<TokenTreeNode>, Request &>
// Returns true when lhs's node score, multiplied by its request's
// get_length_weight(), is strictly greater than the same product for rhs
struct SharedTokenTreeNodePtrRequestWeightedGreater {
bool operator()(
std::pair<std::shared_ptr<TokenTreeNode>, Request &> const &lhs,
std::pair<std::shared_ptr<TokenTreeNode>, Request &> const &rhs) const {
if (lhs.first->gumbel) {
assert(rhs.first->gumbel);
return lhs.first->gumbel_logit * lhs.second.get_length_weight() >
rhs.first->gumbel_logit * rhs.second.get_length_weight();
}
return lhs->log_accumulated_prob < rhs->log_accumulated_prob;
return lhs.first->log_accumulated_prob * lhs.second.get_length_weight() >
rhs.first->log_accumulated_prob * rhs.second.get_length_weight();
}
};

// A comparator for std::pair<std::shared_ptr<TokenTreeNode>, RequestGuid>
// This is used to sort the token tree nodes in ascending order
struct CompareSharedTokenTreeNodePtrRequestGuidPair {
bool operator()(std::pair<std::shared_ptr<TokenTreeNode>,
BatchConfig::RequestGuid> const &lhs,
std::pair<std::shared_ptr<TokenTreeNode>,
BatchConfig::RequestGuid> const &rhs) const {
struct SharedTokenTreeNodePtrRequestGreater {
bool operator()(
std::pair<std::shared_ptr<TokenTreeNode>, Request &> const &lhs,
std::pair<std::shared_ptr<TokenTreeNode>, Request &> const &rhs) const {
if (lhs.first->gumbel) {
assert(rhs.first->gumbel);
return lhs.first->gumbel_logit > rhs.first->gumbel_logit;
Expand All @@ -190,27 +241,6 @@ struct CompareSharedTokenTreeNodePtrRequestGuidPair {
}
};

class TokenTree {
public:
std::list<std::list<shared_ptr<TokenTreeNode>>> tree_layers = {};
// The numebr of tokens in the tree that are not pruned
int tree_size = 0;
// The numebr of tokens in the tree including the pruned ones

void add_layer() {
tree_layers.emplace_back();
}

void clear() {
tree_layers.clear();
tree_size = 0;
}

TokenTree() : tree_size(0) {}
};

// Stream-insertion operator for TokenTree; takes the tree by const reference
// and returns the stream. Only the declaration appears in this section.
std::ostream &operator<<(std::ostream &os, TokenTree const &token_tree);

class RequestManager {
public:
enum State {
Expand Down Expand Up @@ -260,7 +290,17 @@ class RequestManager {
int get_max_tree_width();
void set_max_tree_width(int max_tree_width);
void set_speculative_sampling(bool speculative_sampling);
void set_baseline_latency(double baseline_latency_ms);
double get_baseline_latency();
void set_ssm_spec_latency(double ssm_spec_latency_ms);
double get_ssm_spec_latency();
void set_llm_verify_latency(double llm_verify_latency_ms);
double get_llm_verify_latency();
void set_correction_factor(double correction_factor);
double get_correction_factor();
void set_streaming_cache(bool streaming_cache);
bool get_memory_occupancy();
void set_memory_occupancy(bool memory_occupancy);
int register_ssm_model(FFModel *model);
void register_tokenizer(ModelType model_type,
int bos_token_id,
Expand Down Expand Up @@ -330,13 +370,20 @@ class RequestManager {
int max_tree_depth;
int max_tree_width;
int k;
// Profile based latency
double baseline_latency_ms = 1000;
double ssm_spec_latency_ms = 50;
double llm_verify_latency_ms = 50;
double correction_factor = 1.05;

State request_manager_status;
BackgroundServerStatus background_server_status;
DecodingMode decoding_mode;
PrefillModel prefill_model;
bool speculative_sampling = false;
// specify if enable streaming cache for incremental decoding or draft model
bool streaming_cache = false;
bool memory_occupancy = false;

std::unique_ptr<Tokenizer> tokenizer_;
bool verbose;
Expand Down Expand Up @@ -364,14 +411,6 @@ class RequestManager {
int num_available_requests = 0;
int ssm_completed = true;

// This is a helper data structure that supports the pruning of the token
// trees across different requests.
// TODO: clear this in the first step of the speculation!
std::priority_queue<
std::pair<std::shared_ptr<TokenTreeNode>, RequestGuid>,
std::vector<std::pair<std::shared_ptr<TokenTreeNode>, RequestGuid>>,
CompareSharedTokenTreeNodePtrRequestGuidPair>
token_tree_node_pool;
// rm state
std::mutex rm_state_mutex;

Expand Down Expand Up @@ -455,8 +494,13 @@ class RequestManager {
void init_token_tree(RequestGuid guid);
void add_root_to_spec_token_tree(RequestGuid guid,
BatchConfig::TokenId token_id);
bool add_tokens_to_spec_token_tree(
void add_tokens_to_spec_token_tree(
InferenceResult const &ssm_inference_result);
void prune_token_tree();
void add_tokens_toward_slo(RequestGuid guid, int &budget);
void add_tokens_toward_memory_occupancy(int budget);
void add_tokens_toward_goodput(int budget);

/* ---------- Spec Decoding Helper Functions ---------- */
void renormalize(std::vector<std::pair<TokenId, float>> &D,
std::unordered_map<TokenId, float> &R,
Expand All @@ -471,5 +515,4 @@ class RequestManager {
// Profiling related functions
void reset_profiling_statistics();
};

}; // namespace FlexFlow
Loading
Loading