backup
goliaro committed Oct 4, 2024
1 parent 88d60ca commit e453237
Showing 4 changed files with 84 additions and 26 deletions.
3 changes: 1 addition & 2 deletions include/flexflow/batch_config.h
@@ -96,7 +96,6 @@ class BatchConfig {
request_guid = 0;
prompt_phase = false;
batch_config_request_id = -1;
peft_model_id = PEFTModelID::NO_ID;
peft_bwd = false;
optimizer_tasks = {true, false, false, false};
}
@@ -110,7 +109,7 @@ class BatchConfig {
bool prompt_phase = false;
RequestGuid request_guid;
// PEFT fields
PEFTModelID peft_model_id;
std::unordered_map<PEFTModelID, std::string> peft_adapters;
bool peft_bwd;
OptimizerTasks optimizer_tasks;
};
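The BatchConfig change above swaps the single peft_model_id field for a map from adapter ID to that adapter's serialized config, so a request carries the configs of the adapters it uses inside the batch itself. A minimal stand-in sketch of how such a map might be consumed, with size_t and a raw JSON string standing in for PEFTModelID and the serialized LoraLinearConfig (the names here are illustrative, not from the repository):

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Stand-in for the new per-request peft_adapters field:
  // adapter id -> serialized LoraLinearConfig.
  std::unordered_map<size_t, std::string> peft_adapters;
  peft_adapters[42] = R"({"rank": 16, "lora_alpha": 32.0})";

  // A consumer of the batch walks the request's adapters and can
  // reconstruct each config from its serialized form.
  for (auto const &entry : peft_adapters) {
    std::cout << "adapter " << entry.first << ": " << entry.second << "\n";
  }
  return 0;
}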
48 changes: 29 additions & 19 deletions include/flexflow/ops/lora_linear_params.h
@@ -132,7 +132,7 @@ class LoraLinearConfig {
LoraLinearConfig const &rhs);
friend std::ostream &operator<<(std::ostream &os,
LoraLinearConfig const &llc);
void serialize_to_json_file(const std::string& filename) const {
std::string serialize_to_json_string(int indent=-1) const {
json j = {
{"cache_folder", cache_folder},
{"peft_model_id", peft_model_id},
@@ -147,30 +147,40 @@ class LoraLinearConfig {
{"optimizer_config", optimizer_config ? optimizer_config->toJson() : nullptr}
};

return j.dump(indent); // indent = -1 (the default) yields the compact single-line form
}
void serialize_to_json_file(const std::string& filename) const {
std::string j = serialize_to_json_string(4);
std::ofstream file(filename);
file << j.dump(4); // Use 4 spaces for indentation
file << j;
}
// Deserialization method
static LoraLinearConfig deserialize_from_json_file(const std::string& filename) {
std::ifstream file(filename);
json j;
file >> j;
LoraLinearConfig metadata(
j["cache_folder"].get<std::string>(),
j["peft_model_id"].get<std::vector<int>>(),
j["rank"].get<std::string>(),
j["lora_alpha"].get<std::string>(),
j["lora_dropout"].get<std::string>(),
j["target_modules"].get<std::vector<std::string>>(),
j["trainable"].get<bool>(),
j["init_lora_weights"].get<bool>(),
j["base_model_name_or_path"].get<std::string>(),
j["precision"].get<std::string>()
static LoraLinearConfig deserialize_from_json_string(const std::string& json_string) {
json j = json::parse(json_string);
LoraLinearConfig config(
j["cache_folder"].get<std::string>(),
j["peft_model_id"].get<std::string>(),
j["trainable"].get<bool>(),
nullptr, // optimizer_config will be set later if present
j["init_lora_weights"].get<bool>(),
j["base_model_name_or_path"].get<std::string>(),
j["precision"].get<std::string>(),
j["rank"].get<int>(),
j["lora_alpha"].get<float>(),
j["lora_dropout"].get<float>(),
j["target_modules"].get<std::vector<std::string>>()
);
if (!j["optimizer_config"].is_null()) {
metadata.optimizer_config = LoraOptimizerConfig::fromJson(j["optimizer_config"]);
config.setOptimizer(LoraOptimizerConfig::fromJson(j["optimizer_config"]));
}
return metadata;
return config;
}
// Deserialization method
static LoraLinearConfig deserialize_from_json_file(const std::string& filename) {
std::ifstream file(filename);
std::stringstream buffer; // requires <sstream>
buffer << file.rdbuf(); // read the whole file; operator>> into a std::string stops at the first whitespace
return deserialize_from_json_string(buffer.str());
}

std::string cache_folder;
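For reference, a standalone sketch of the same serialize/parse round trip, assuming the json alias above refers to nlohmann::json (which is what dump(indent) and json::parse suggest); the file name and fields below are illustrative only:

#include <fstream>
#include <sstream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  // Serialize: a non-negative indent pretty-prints, -1 (the dump() default)
  // produces the compact single-line form.
  json j = {{"rank", 16}, {"lora_alpha", 32.0}, {"optimizer_config", nullptr}};
  std::ofstream out("lora_config.json");
  out << j.dump(4);
  out.close();

  // Deserialize: read the whole file into a string first, since operator>> on a
  // std::string would stop at the first whitespace of the pretty-printed JSON.
  std::ifstream in("lora_config.json");
  std::stringstream buffer;
  buffer << in.rdbuf();
  json parsed = json::parse(buffer.str());

  int rank = parsed["rank"].get<int>();
  bool has_optimizer = !parsed["optimizer_config"].is_null();
  return (rank == 16 && !has_optimizer) ? 0 : 1;
}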
3 changes: 3 additions & 0 deletions include/flexflow/request_manager.h
@@ -149,6 +149,8 @@ class RequestManager {
int eos_token_id,
std::string const &path);
void register_output_filepath(std::string const &);
void register_peft_config(PEFTModelID const &peft_model_id, LoraLinearConfig const &peft_config);
LoraLinearConfig const &get_peft_config(PEFTModelID const &peft_model_id);
void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength);
void appendBitMask(BatchConfig::BitMask &bitmask,
@@ -289,6 +291,7 @@ class RequestManager {
int max_sequence_length;
Status request_manager_status;

std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
// peft benchmarking
bool enable_peft_finetuning = false;
static bool inference_finished;
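The new peft_configs member keys an std::unordered_map on PEFTModelID, which only compiles if an std::hash specialization and operator== exist for that type; FlexFlow presumably already provides these for PEFTModelID, but a minimal stand-in sketch of the required pattern, using a hypothetical Id wrapper, looks like this:

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for PEFTModelID: a value type wrapping a numeric guid.
struct Id {
  size_t guid;
  bool operator==(Id const &other) const { return guid == other.guid; }
};

// Hash support so Id can be used as an unordered_map key.
namespace std {
template <>
struct hash<Id> {
  size_t operator()(Id const &id) const noexcept {
    return std::hash<size_t>{}(id.guid);
  }
};
} // namespace std

int main() {
  std::unordered_map<Id, std::string> configs;
  configs[Id{1}] = R"({"rank": 16})";
  return configs.count(Id{1}) == 1 ? 0 : 1;
}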
56 changes: 51 additions & 5 deletions src/runtime/request_manager.cc
@@ -255,6 +255,46 @@ size_t RequestManager::get_num_ssms() {
return ssm_models.size();
}

void RequestManager::register_peft_config(PEFTModelID const &peft_model_id,
LoraLinearConfig const &peft_config) {
// check that peft_model_id is not already in use
assert(peft_configs.find(peft_model_id) == peft_configs.end() &&
"PEFT model ID already in use");
peft_configs[peft_model_id] = peft_config;
}

LoraLinearConfig const &RequestManager::get_peft_config(
PEFTModelID const &peft_model_id) {
assert(peft_configs.find(peft_model_id) != peft_configs.end() &&
"PEFT model ID not found");
return peft_configs.at(peft_model_id); // at() avoids default-inserting a config when asserts are disabled
}

PEFTModelID *FFModel::register_peft_adapter(LoraLinearConfig const peft_config) {
assert(config.enable_peft &&
"Cannot add a LoRA layer if PEFT mode is not enabled");
if (peft_config.target_modules.size() == 0) {
printf("PEFT config does not contain any target module\n");
std::cout << peft_config << std::endl;
assert(false);
}
// go over base_layer_to_peft_layer and check that you can find at least one match
for (size_t i = 0; i < peft_config.target_modules.size(); i++) {
bool found = false;
for (auto const &base_layer : peft_config.base_layer_to_peft_layer) {
if (base_layer.name != nullptr && strlen(base_layer.name) > 0 &&
    std::string(base_layer.name).find(peft_config.target_modules[i]) != std::string::npos) {
found = true;
break;
}
}
assert(found && "Attempting to add LoRA to a LLM target module that does not exist or does not support LoRA");
}
PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++);
RequestManager *rm = RequestManager::get_request_manager();
rm->register_peft_config(*peft_model_id, peft_config);
return peft_model_id;
}

RequestManager::RequestGuid
RequestManager::register_new_request(Request const &request_) {
const std::lock_guard<std::mutex> lock(request_queue_mutex);
@@ -730,8 +770,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens;
new_bc.requestsInfo[i].request_guid =
old_bc.requestsInfo[i].request_guid;
new_bc.requestsInfo[i].peft_model_id =
old_bc.requestsInfo[i].peft_model_id;
// new_bc.requestsInfo[i].peft_model_id =
// old_bc.requestsInfo[i].peft_model_id;
new_bc.requestsInfo[i].peft_adapters =
old_bc.requestsInfo[i].peft_adapters;
new_bc.requestsInfo[i].peft_bwd = old_bc.requestsInfo[i].peft_bwd;
new_bc.requestsInfo[i].max_length = old_bc.requestsInfo[i].max_length;
num_active_req++;
@@ -800,7 +842,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
(int)new_request.tokens.size());
new_bc.requestsInfo[i].max_length = new_request.max_length;
new_bc.requestsInfo[i].peft_model_id = new_request.peft_model_id;
// new_bc.requestsInfo[i].peft_model_id = new_request.peft_model_id;
if (new_request.peft_model_id != PEFTModelID::NO_ID) {
new_bc.requestsInfo[i].peft_adapters[new_request.peft_model_id] = get_peft_config(new_request.peft_model_id).serialize_to_json_string();
}
new_bc.requestsInfo[i].peft_bwd = false;
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].prompt_phase = true;
@@ -967,8 +1012,9 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
num_peft_tokens;
new_bc.requestsInfo[inference_batch_size].max_length = request.max_length;
new_bc.requestsInfo[inference_batch_size].request_guid = request.guid;
new_bc.requestsInfo[inference_batch_size].peft_model_id =
request.peft_model_id;
// new_bc.requestsInfo[inference_batch_size].peft_model_id =
// request.peft_model_id;
new_bc.requestsInfo[inference_batch_size].peft_adapters[request.peft_model_id] = get_peft_config(request.peft_model_id).serialize_to_json_string();
new_bc.requestsInfo[inference_batch_size].peft_bwd = true;
set_optimizer_tasks(
new_bc.requestsInfo[inference_batch_size].optimizer_tasks,
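Taken together, the RequestManager changes amount to a small registry: an adapter is registered once under a fresh PEFTModelID, and prepare_next_batch copies its serialized config into the request's peft_adapters map. A self-contained sketch of that flow with stand-in types (the class and member names below are illustrative, not FlexFlow APIs):

#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Stand-in for the RequestManager-side registry introduced in this commit.
class AdapterRegistry {
public:
  // Mirrors register_peft_config: store the serialized config under a fresh id.
  size_t register_config(std::string config_json) {
    size_t id = next_guid++;
    bool inserted = configs.emplace(id, std::move(config_json)).second;
    assert(inserted && "PEFT model ID already in use");
    return id;
  }
  // Mirrors get_peft_config: look up a previously registered config.
  std::string const &get_config(size_t id) const {
    auto it = configs.find(id);
    if (it == configs.end()) {
      throw std::out_of_range("PEFT model ID not found");
    }
    return it->second;
  }

private:
  size_t next_guid = 0;
  std::unordered_map<size_t, std::string> configs;
};

int main() {
  AdapterRegistry rm;
  size_t id = rm.register_config(R"({"rank": 16, "lora_alpha": 32.0})");

  // At batch-preparation time the serialized config travels with the request,
  // as prepare_next_batch does for requestsInfo[i].peft_adapters above.
  std::unordered_map<size_t, std::string> peft_adapters;
  peft_adapters[id] = rm.get_config(id);
  return peft_adapters.size() == 1 ? 0 : 1;
}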
