goliaro committed Oct 5, 2024
1 parent e453237 commit 5c8c448
Showing 5 changed files with 72 additions and 13 deletions.
2 changes: 1 addition & 1 deletion include/flexflow/model.h
@@ -846,7 +846,7 @@ class FFModel {
// PEFT Layers
// ========================================
// PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config);
void add_lora_layers(std::vector<std::string> target_modules, int max_rank, int max_concurrent_adapters);
void add_lora_layers(std::vector<std::string> target_modules);
// ========================================
// Inference APIs
// ========================================
2 changes: 1 addition & 1 deletion include/flexflow/ops/lora_linear_params.h
@@ -208,7 +208,7 @@ class LoraLinearConfig {
class LoraLinearParams {
public:
LayerID layer_guid;
OperatorType type;
// OperatorType type;
// std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
int max_rank;
int max_concurrent_adapters;
9 changes: 8 additions & 1 deletion include/flexflow/request_manager.h
@@ -151,6 +151,10 @@ class RequestManager {
void register_output_filepath(std::string const &);
void register_peft_model(FFModel *model, PEFTModelID peft_model_id);
LoraLinearConfig get_peft_config(PEFTModelID peft_model_id);
void set_max_lora_rank(int max_lora_rank);
void set_max_concurrent_adapters(int max_concurrent_adapters);
int get_max_lora_rank();
int get_max_concurrent_adapters();
void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength);
void appendBitMask(BatchConfig::BitMask &bitmask,
@@ -290,8 +294,11 @@ class RequestManager {
int max_spec_tree_token_num;
int max_sequence_length;
Status request_manager_status;


// peft
std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
int max_lora_rank;
int max_concurrent_adapters;
// peft benchmarking
bool enable_peft_finetuning = false;
static bool inference_finished;
41 changes: 32 additions & 9 deletions src/ops/lora_linear.cc
@@ -51,10 +51,13 @@ bool check_lora_layer_match(Layer *potential_target,
return false;
}

void FFModel::add_lora_layers(std::vector<std::string> target_modules, int max_rank, int max_concurrent_adapters) {
void FFModel::add_lora_layers(std::vector<std::string> target_modules) {
assert(config.enable_peft && "Cannot add a LoRA layer if PEFT mode is not enabled");
assert(target_modules.size() > 0 && "LoRA target module name is empty");
assert(max_rank > 1 && max_rank <= 32 && "Invalid max LoRA rank");
RequestManager *rm = RequestManager::get_request_manager();
int max_lora_rank = rm->get_max_lora_rank();
int max_concurrent_adapters = rm->get_max_concurrent_adapters();
assert(max_lora_rank > 1 && max_lora_rank <= 32 && "Invalid max LoRA rank");
assert(max_concurrent_adapters > 0 && "Invalid number of LoRA concurrent adapters");

for (std::string target_module_name : target_modules) {
@@ -1197,14 +1200,17 @@ bool LoraLinear::measure_operator_cost(Simulator *sim,
}

bool operator==(LoraLinearParams const &lhs, LoraLinearParams const &rhs) {
if (lhs.layer_guid == rhs.layer_guid && lhs.type == rhs.type &&
lhs.peft_configs.size() == rhs.peft_configs.size()) {
if (lhs.layer_guid == rhs.layer_guid && lhs.max_rank == rhs.max_rank &&
lhs.max_concurrent_adapters == rhs.max_concurrent_adapters &&
strcmp(lhs.name, rhs.name) == 0) {
#ifdef DEADCODE
for (auto const &kv : lhs.peft_configs) {
auto it = rhs.peft_configs.find(kv.first);
if (it == rhs.peft_configs.end() || !(it->second == kv.second)) {
return false;
}
}
#endif
return true;
}
return false;
@@ -1243,6 +1249,9 @@ void LoraLinear::serialize(Legion::Serializer &sez) const {
sez.serialize(this->layer_guid.id);
sez.serialize(this->layer_guid.transformer_layer_id);
sez.serialize(this->layer_guid.model_id);
sez.serialize(this->max_rank);
sez.serialize(this->max_concurrent_adapters);
#ifdef DEADCODE
sez.serialize(this->op_type);
sez.serialize(this->peft_configs.size());
for (auto const &kv : this->peft_configs) {
@@ -1285,6 +1294,7 @@ void LoraLinear::serialize(Legion::Serializer &sez) const {
}
}
}
#endif
sez.serialize(strlen(this->name));
sez.serialize(this->name, strlen(this->name));
}
@@ -1297,8 +1307,9 @@ Node LoraLinear::deserialize(FFModel &ff,
int num_inputs) {
assert(num_inputs == 2);
size_t id, transformer_layer_id, deserialized_model_id;
OperatorType op_type;
size_t num_pefts;
int max_rank, max_concurrent_adapters;
// OperatorType op_type;
// size_t num_pefts;
size_t name_len;
char name[MAX_OPNAME] = {0};

@@ -1307,6 +1318,9 @@
dez.deserialize(id);
dez.deserialize(transformer_layer_id);
dez.deserialize(deserialized_model_id);
dez.deserialize(max_rank);
dez.deserialize(max_concurrent_adapters);
#ifdef DEADCODE
dez.deserialize(op_type);
dez.deserialize(num_pefts);
for (int i = 0; i < num_pefts; i++) {
@@ -1357,12 +1371,15 @@ Node LoraLinear::deserialize(FFModel &ff,
params.peft_configs.emplace(
std::make_pair(peft_model_id, *lora_linear_config));
}
#endif
dez.deserialize(name_len);
dez.deserialize(name, name_len);
LayerID layer_guid(id, transformer_layer_id, deserialized_model_id);

params.layer_guid = layer_guid;
params.type = op_type;
// params.type = op_type;
params.max_rank = max_rank;
params.max_concurrent_adapters = max_concurrent_adapters;
strcpy(params.name, name);
return ff.get_or_create_node<LoraLinear>({inputs[0], inputs[1]}, params);
}
@@ -1377,11 +1394,13 @@ Op *LoraLinear::materialize(FFModel &ff,
LoraLinearParams LoraLinear::get_params() const {
LoraLinearParams params;
params.layer_guid = this->layer_guid;
params.type = this->op_type;
params.max_rank = this->max_rank;
params.max_concurrent_adapters = this->max_concurrent_adapters;
// params.type = this->op_type;
if (strlen(this->name) < MAX_OPNAME) {
strcpy(params.name, this->name);
}
params.peft_configs = this->peft_configs;
// params.peft_configs = this->peft_configs;
return params;
}

@@ -1400,6 +1419,9 @@ size_t hash<FlexFlow::LoraLinearParams>::operator()(
hash_combine(key, params.layer_guid.id);
hash_combine(key, params.layer_guid.transformer_layer_id);
hash_combine(key, params.layer_guid.model_id);
hash_combine(key, params.max_rank);
hash_combine(key, params.max_concurrent_adapters);
#ifdef DEADCODE
for (auto const &kv : params.peft_configs) {
hash_combine(key, kv.first.id);
hash_combine(key, kv.second.rank);
@@ -1411,6 +1433,7 @@ size_t hash<FlexFlow::LoraLinearParams>::operator()(
hash_combine(key, kv.second.target_modules);
hash_combine(key, kv.second.init_lora_weights);
}
#endif
return key;
}
}; // namespace std
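
Note: a minimal sketch of what the new equality and hashing semantics above imply. With this commit, LoraLinearParams is compared and hashed on layer_guid, max_rank, max_concurrent_adapters, and name; the per-adapter peft_configs no longer participate. The field values below are illustrative assumptions, not taken from the commit.

// Illustrative fragment (assumes the FlexFlow headers are included and this
// runs inside a function body).
using namespace FlexFlow;
LoraLinearParams a, b;
a.layer_guid = LayerID(/*id=*/0, /*transformer_layer_id=*/0, /*model_id=*/0);
b.layer_guid = a.layer_guid;
a.max_rank = b.max_rank = 16;                               // assumed rank cap
a.max_concurrent_adapters = b.max_concurrent_adapters = 4;  // assumed adapter cap
strcpy(a.name, "layers_0_mlp_down_proj.lora");              // name is illustrative
strcpy(b.name, "layers_0_mlp_down_proj.lora");
assert(a == b);  // true: adapter configs are ignored by the new operator==
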
31 changes: 30 additions & 1 deletion src/runtime/request_manager.cc
@@ -270,6 +270,20 @@ LoraLinearConfig const &RequestManager::get_peft_config(
return peft_configs[peft_model_id];
}

void RequestManager::set_max_lora_rank(int max_lora_rank_) {
max_lora_rank = max_lora_rank_;
}

void RequestManager::set_max_concurrent_adapters(int max_concurrent_adapters_) {
max_concurrent_adapters = max_concurrent_adapters_;
}

int RequestManager::get_max_lora_rank() { return max_lora_rank; }

int RequestManager::get_max_concurrent_adapters() {
return max_concurrent_adapters;
}

PEFTModelID *FFModel::register_peft_adapter(LoraLinearConfig const peft_config) {
assert(config.enable_peft &&
"Cannot add a LoRA layer if PEFT mode is not enabled");
@@ -679,6 +693,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
int inference_batch_size =
BatchConfig::max_requests_per_batch() - (int)enable_peft_finetuning;

int num_concurrent_adapters = 0;

// Step 2: prepare the next batch for existing inference requests
BatchConfig new_bc;
for (int i = 0; i < inference_batch_size; i++) {
@@ -774,6 +790,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
// old_bc.requestsInfo[i].peft_model_id;
new_bc.requestsInfo[i].peft_adapters =
old_bc.requestsInfo[i].peft_adapters;
num_concurrent_adapters += new_bc.requestsInfo[i].peft_adapters.size();
new_bc.requestsInfo[i].peft_bwd = old_bc.requestsInfo[i].peft_bwd;
new_bc.requestsInfo[i].max_length = old_bc.requestsInfo[i].max_length;
num_active_req++;
@@ -825,13 +842,22 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
}
new_bc.num_generation_tokens = num_generation_tokens;

assert(num_concurrent_adapters <= get_max_concurrent_adapters() &&
"Number of concurrent adapters exceeded the limit");

// Step 3: add new inference requests to the next batch if there is space
for (int i = 0; i < inference_batch_size; i++) {
if (new_bc.request_completed[i]) {
if (!pending_infr_request_queue.empty() &&
new_bc.num_tokens < get_max_tokens_per_batch()) {
Request new_request = pending_infr_request_queue.front();
assert(new_request.req_type == RequestType::REQ_INFERENCE);

// if the request has peft adapters and we are at capacity, don't add it yet
if (new_request.peft_model_id != PEFTModelID::NO_ID && num_concurrent_adapters == get_max_concurrent_adapters()) {
break;
}

pending_infr_request_queue.pop();
// all_requests[new_request.guid] = new_request;

@@ -1000,7 +1026,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
int num_peft_label_tokens = request.dataset[dataset_entry].second.size();
assert(num_peft_label_tokens == 0);

if (num_peft_tokens > 0) {
if (num_peft_tokens > 0 && num_concurrent_adapters < get_max_concurrent_adapters()) {
assert(new_bc.request_completed[inference_batch_size]);
// request info
new_bc.request_completed[inference_batch_size] = false;
@@ -1033,8 +1059,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
new_bc.num_tokens++;
new_bc.num_peft_tokens++;
}
num_concurrent_adapters +=1;
}
}
assert(num_concurrent_adapters <= get_max_concurrent_adapters() &&
"Number of concurrent adapters exceeded the limit");
return new_bc;
}
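
For reference, a minimal usage sketch of the API after this change: the LoRA limits are now set once on the RequestManager, and FFModel::add_lora_layers reads them back instead of taking them as arguments. The concrete values (16, 4), the "down_proj" target module, and the surrounding model object are illustrative assumptions, not part of the commit.

// Illustrative fragment (assumes an FFModel named `model` built with PEFT enabled).
RequestManager *rm = RequestManager::get_request_manager();
rm->set_max_lora_rank(16);            // assumed cap; add_lora_layers requires 1 < rank <= 32
rm->set_max_concurrent_adapters(4);   // assumed cap; must be > 0
model.add_lora_layers({"down_proj"}); // fetches both limits from the RequestManager internally
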
