diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 5ac91d5b81..d1dbe72d7c 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -846,7 +846,7 @@ class FFModel { // PEFT Layers // ======================================== // PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config); - void add_lora_layers(std::vector target_modules, int max_rank, int max_concurrent_adapters); + void add_lora_layers(std::vector target_modules); // ======================================== // Inference APIs // ======================================== diff --git a/include/flexflow/ops/lora_linear_params.h b/include/flexflow/ops/lora_linear_params.h index 84e76c4cc7..c5a327459f 100644 --- a/include/flexflow/ops/lora_linear_params.h +++ b/include/flexflow/ops/lora_linear_params.h @@ -208,7 +208,7 @@ class LoraLinearConfig { class LoraLinearParams { public: LayerID layer_guid; - OperatorType type; + // OperatorType type; // std::unordered_map peft_configs; int max_rank; int max_concurrent_adapters; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index bff0e4d90c..fcb09f15ed 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -151,6 +151,10 @@ class RequestManager { void register_output_filepath(std::string const &); void register_peft_model(FFModel *model, PEFTModelID peft_model_id); LoraLinearConfig get_peft_config(PEFTModelID peft_model_id); + void set_max_lora_rank(int max_lora_rank); + void set_max_concurrent_adapters(int max_concurrent_adapters); + int get_max_lora_rank(); + int get_max_concurrent_adapters(); void initBitMask(BatchConfig::BitMask &bitmask, int initLength); void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength); void appendBitMask(BatchConfig::BitMask &bitmask, @@ -290,8 +294,11 @@ class RequestManager { int max_spec_tree_token_num; int max_sequence_length; Status request_manager_status; - + + // peft std::unordered_map peft_configs; + int max_lora_rank; + int max_concurrent_adapters; // peft benchmarking bool enable_peft_finetuning = false; static bool inference_finished; diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index f4c1ba9c35..1ba11ed75e 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -51,10 +51,13 @@ bool check_lora_layer_match(Layer *potential_target, return false; } -void FFmodel::add_lora_layers(std::vector target_modules, int max_rank, int max_concurrent_adapters) { +void FFmodel::add_lora_layers(std::vector target_modules) { assert(config.enable_peft && "Cannot add a LoRA layer if PEFT mode is not enabled"); assert(target_modules.size() > 0 && "LoRA target module name is empty"); - assrt(max_rank > 1 && max_rank <= 32 && "Invalid max LoRA rank"); + RequestManager *rm = RequestManager::get_request_manager(); + int max_lora_rank = rm->get_max_lora_rank(); + int max_concurrent_adapters = rm->get_max_concurrent_adapters(); + assert(max_rank > 1 && max_rank <= 32 && "Invalid max LoRA rank"); assert(max_concurrent_adapters > 0 && "Invalid number of LoRA concurrent adapters"); for (std::string target_module_name : target_modules) { @@ -1197,14 +1200,17 @@ bool LoraLinear::measure_operator_cost(Simulator *sim, } bool operator==(LoraLinearParams const &lhs, LoraLinearParams const &rhs) { - if (lhs.layer_guid == rhs.layer_guid && lhs.type == rhs.type && - lhs.peft_configs.size() == rhs.peft_configs.size()) { + if (lhs.layer_guid == rhs.layer_guid && lhs.max_rank == rhs.max_rank && + lhs.max_concurrent_adapters == rhs.max_concurrent_adapters && + strcmp(lhs.name, rhs.name) == 0) { +#ifdef DEADCODE for (auto const &kv : lhs.peft_configs) { auto it = rhs.peft_configs.find(kv.first); if (it == rhs.peft_configs.end() || !(it->second == kv.second)) { return false; } } +#endif return true; } return false; @@ -1243,6 +1249,9 @@ void LoraLinear::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->layer_guid.model_id); + sez.serialize(this->max_rank); + sez.serialize(this->max_concurrent_adapters); +#ifdef DEADCODE sez.serialize(this->op_type); sez.serialize(this->peft_configs.size()); for (auto const &kv : this->peft_configs) { @@ -1285,6 +1294,7 @@ void LoraLinear::serialize(Legion::Serializer &sez) const { } } } +#endif sez.serialize(strlen(this->name)); sez.serialize(this->name, strlen(this->name)); } @@ -1297,8 +1307,9 @@ Node LoraLinear::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 2); size_t id, transformer_layer_id, deserialized_model_id; - OperatorType op_type; - size_t num_pefts; + int max_rank, max_concurrent_adapters; + // OperatorType op_type; + // size_t num_pefts; size_t name_len; char name[MAX_OPNAME] = {0}; @@ -1307,6 +1318,9 @@ Node LoraLinear::deserialize(FFModel &ff, dez.deserialize(id); dez.deserialize(transformer_layer_id); dez.deserialize(deserialized_model_id); + dez.deserialize(max_rank); + dez.deserialize(max_concurrent_adapters); +#ifdef DEADCODE dez.deserialize(op_type); dez.deserialize(num_pefts); for (int i = 0; i < num_pefts; i++) { @@ -1357,12 +1371,15 @@ Node LoraLinear::deserialize(FFModel &ff, params.peft_configs.emplace( std::make_pair(peft_model_id, *lora_linear_config)); } +#endif dez.deserialize(name_len); dez.deserialize(name, name_len); LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); params.layer_guid = layer_guid; - params.type = op_type; + // params.type = op_type; + params.max_rank = max_rank; + params.max_concurrent_adapters = max_concurrent_adapters; strcpy(params.name, name); return ff.get_or_create_node({inputs[0], inputs[1]}, params); } @@ -1377,11 +1394,13 @@ Op *LoraLinear::materialize(FFModel &ff, LoraLinearParams LoraLinear::get_params() const { LoraLinearParams params; params.layer_guid = this->layer_guid; - params.type = this->op_type; + params.max_rank = this->max_rank; + params.max_concurrent_adapters = this->max_concurrent_adapters; + // params.type = this->op_type; if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } - params.peft_configs = this->peft_configs; + // params.peft_configs = this->peft_configs; return params; } @@ -1400,6 +1419,9 @@ size_t hash::operator()( hash_combine(key, params.layer_guid.id); hash_combine(key, params.layer_guid.transformer_layer_id); hash_combine(key, params.layer_guid.model_id); + hash_combine(key, params.max_rank); + hash_combine(key, params.max_concurrent_adapters); +#ifdef DEADCODE for (auto const &kv : params.peft_configs) { hash_combine(key, kv.first.id); hash_combine(key, kv.second.rank); @@ -1411,6 +1433,7 @@ size_t hash::operator()( hash_combine(key, kv.second.target_modules); hash_combine(key, kv.second.init_lora_weights); } +#endif return key; } }; // namespace std diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 5e9a724d3f..79fcdfdcfe 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -270,6 +270,20 @@ LoraLinearConfig const &RequestManager::get_peft_config( return peft_configs[peft_model_id]; } +void RequestManager::set_max_lora_rank(int max_lora_rank_) { + max_lora_rank = max_lora_rank_; +} + +void RequestManager::set_max_concurrent_adapters(int max_concurrent_adapters_) { + max_concurrent_adapters = max_concurrent_adapters_; +} + +int RequestManager::get_max_lora_rank() { return max_lora_rank; } + +int RequestManager::get_max_concurrent_adapters() { + return max_concurrent_adapters; +} + PEFTModelID *FFModel::register_peft_adapter(LoraLinearConfig const peft_config) { assert(config.enable_peft && "Cannot add a LoRA layer if PEFT mode is not enabled"); @@ -679,6 +693,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, int inference_batch_size = BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning; + int num_concurrent_adapters = 0; + // Step 2: prepare the next batch for existing inference requests BatchConfig new_bc; for (int i = 0; i < inference_batch_size; i++) { @@ -774,6 +790,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, // old_bc.requestsInfo[i].peft_model_id; new_bc.requestsInfo[i].peft_adapters = old_bc.requestsInfo[i].peft_adapters; + num_concurrent_adapters += new_bc.requestsInfo[i].peft_adapters.size(); new_bc.requestsInfo[i].peft_bwd = old_bc.requestsInfo[i].peft_bwd; new_bc.requestsInfo[i].max_length = old_bc.requestsInfo[i].max_length; num_active_req++; @@ -825,6 +842,9 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } new_bc.num_generation_tokens = num_generation_tokens; + assert(num_concurrent_adapters <= get_max_concurrent_adapters() && + "Number of concurrent adapters exceeded the limit"); + // Step 3: add new inference requests to the next batch if there is space for (int i = 0; i < inference_batch_size; i++) { if (new_bc.request_completed[i]) { @@ -832,6 +852,12 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_bc.num_tokens < get_max_tokens_per_batch()) { Request new_request = pending_infr_request_queue.front(); assert(new_request.req_type == RequestType::REQ_INFERENCE); + + // if the request has peft adapters and we are at capacity, don't add it yet + if (new_request.peft_model_id != PEFTModelID::NO_ID && num_concurrent_adapters == get_max_concurrent_adapters()) { + break; + } + pending_infr_request_queue.pop(); // all_requests[new_request.guid] = new_request; @@ -1000,7 +1026,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, int num_peft_label_tokens = request.dataset[dataset_entry].second.size(); assert(num_peft_label_tokens == 0); - if (num_peft_tokens > 0) { + if (num_peft_tokens > 0 && num_concurrent_adapters < get_max_concurrent_adapters()) { assert(new_bc.request_completed[inference_batch_size]); // request info new_bc.request_completed[inference_batch_size] = false; @@ -1033,8 +1059,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_bc.num_tokens++; new_bc.num_peft_tokens++; } + num_concurrent_adapters +=1; } } + assert(num_concurrent_adapters <= get_max_concurrent_adapters() && + "Number of concurrent adapters exceeded the limit"); return new_bc; }