diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index b88cddcc18..e805dc2ace 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -1014,6 +1014,12 @@ void flexflow_request_manager_start_background_server( void flexflow_request_manager_terminate_background_server( flexflow_request_manager_t handle_); +void flexflow_request_manager_save_peft_weights( + flexflow_request_manager_t handle_, + flexflow_model_t model_handle_, + flexflow_peft_model_id_t peft_model_id_, + char const *destination_folder); + // ----------------------------------------------------------------------- // InferenceManager // ----------------------------------------------------------------------- diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 099e2209e4..5f8c3a2de6 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -129,7 +129,7 @@ enum TaskIDs { LINEAR_BWD2_TASK_ID, LINEAR_UPD_TASK_ID, LORA_LINEAR_INIT_TASK_ID, - LORA_LINEAR_REG_TASK_ID, + LORA_LINEAR_SAVE_WEIGHTS_TASK_ID, LORA_LINEAR_INF_TASK_ID, LORA_LINEAR_PEFT_BWD_TASK_ID, FLAT_INIT_TASK_ID, diff --git a/include/flexflow/ops/lora_linear.h b/include/flexflow/ops/lora_linear.h index 9e83c3f90e..48d130a230 100644 --- a/include/flexflow/ops/lora_linear.h +++ b/include/flexflow/ops/lora_linear.h @@ -41,6 +41,13 @@ class LoraLinear : public Op { MachineView const *mv = nullptr) override; void forward(FFModel const &) override; void backward(FFModel const &) override; + void save_peft_weights(FFModel const &ff, + PEFTModelID const &model_id, + int rank, + std::string const &destination_folder, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv = nullptr); Legion::FutureMap inference(FFModel const &, BatchConfigFuture const &, std::vector const &, @@ -69,6 +76,11 @@ class LoraLinear : public Op { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void + save_peft_weights_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void forward_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 889204a1a1..9d19485b25 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -36,6 +36,9 @@ class InferenceManager { static InferenceManager *get_inference_manager(); void compile_model_and_allocate_buffer(FFModel *model); void init_operators_inference(FFModel *model); + void save_peft_weights(FFModel *model, + PEFTModelID const &model_id, + std::string const &destination_folder); Legion::FutureMap inference(FFModel *model, int index, BatchConfig const &bc); Legion::FutureMap inference(FFModel *model, int index, BatchConfigFuture const &bc); @@ -161,6 +164,10 @@ class RequestManager { FFModel *get_ssm_model(int model_id); + void save_peft_weights(FFModel *model, + PEFTModelID const &model_id, + std::string const &destination_folder); + void serve_incr_decoding(FFModel *model); void serve_spec_infer(FFModel *model); GenerationResult get_generation_result(RequestGuid const &guid); diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc index 09131c625a..d97e20f6ff 100644 --- a/inference/peft/peft.cc +++ b/inference/peft/peft.cc @@ -372,6 +372,10 @@ void FlexFlow::top_level_task(Task const *task, future.get_void_result(); } + rm->save_peft_weights(&model, + *peft_model_id, + 
std::string("/root/.cache/flexflow/finetuned_weights")); + if (peft_model_id != nullptr) { free(peft_model_id); } diff --git a/inference/python/ff_peft.py b/inference/python/ff_peft.py index b367272aca..cabb9a62e1 100644 --- a/inference/python/ff_peft.py +++ b/inference/python/ff_peft.py @@ -25,8 +25,22 @@ def get_configs(): type=str, default="", ) - args = parser.parse_args() + parser.add_argument( + "--publish-peft-with-id", + help="The Hugging Face model ID to upload the trained model with", + type=str, + default="" + ) + args = parser.parse_args() + publish_peft_with_id = args.publish_peft_with_id + if len(publish_peft_with_id) == 0: + print( + "Please pass a --publish-peft-with-id if you want to upload the trained model" + ) + else: + print(f"The trained model will be uploaded with id: {publish_peft_with_id}") + # Load configs from JSON file (if specified) if len(args.config_file) > 0: if not os.path.isfile(args.config_file): @@ -67,7 +81,7 @@ def get_configs(): "inference_peft_model_id": "goliaro/llama-160m-lora", "finetuning_peft_model_id": "goliaro/llama-160m-lora", # optional parameters - "cache_path": "", + "cache_path": "~/.cache/flexflow", "refresh_cache": False, "full_precision": True, "prompt": "", @@ -75,10 +89,11 @@ def get_configs(): os.path.dirname(os.path.abspath(__file__)), "../prompt/peft_dataset.json", ), - "output_file": "", + "output_file": "" } # Merge dictionaries ff_init_configs.update(model_configs) + ff_init_configs["publish_peft_with_id"] = publish_peft_with_id return ff_init_configs @@ -98,7 +113,7 @@ def main(): data_type=ff_data_type, cache_path=configs.cache_path, refresh_cache=configs.refresh_cache, - output_file=configs.output_file, + output_file=configs.output_file ) # Add inference and/or finetuning lora lora_inference_config = None @@ -146,6 +161,8 @@ def main(): ) llm.start_server() + + print(f"LLM model class is: {llm.model_class}") requests = [] # Serving @@ -173,9 +190,17 @@ def main(): requests.append(finetuning_request) llm.generate(requests) - + llm.stop_server() - + + # upload the model back to huggingface after finetuning + # the model format would be converted from flexflow format back to huggingface format + if len(configs.publish_peft_with_id) > 0: + print( + f"Done training! Uploading the model to HF hub with id: {configs.publish_peft_with_id}..." + ) + llm.upload_peft_model(configs.publish_peft_with_id, private=True) + if __name__ == "__main__": print("flexflow PEFT example") diff --git a/inference/python/peft_metrics.py b/inference/python/peft_metrics.py new file mode 100644 index 0000000000..2d6d969b01 --- /dev/null +++ b/inference/python/peft_metrics.py @@ -0,0 +1,273 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import flexflow.serve as ff +import argparse, json, os +from types import SimpleNamespace +import time +import subprocess +import psutil +import time +import json + + +def get_gpu_utilization(): + try: + result = subprocess.run(['nvidia-smi', '--query-gpu=utilization.gpu,memory.used', '--format=csv,noheader,nounits'], stdout=subprocess.PIPE) + output = result.stdout.decode('utf-8').strip() + lines = output.split('\n') + + total_gpu_utilization = 0.0 + total_memory_used = 0.0 + num_gpus = len(lines) + + for line in lines: + try: + gpu_utilization, memory_used = line.split(', ') + total_gpu_utilization += float(gpu_utilization) + total_memory_used += float(memory_used) + except ValueError: + print("Error parsing line:", line) + num_gpus -= 1 # Adjust num_gpus in case of parsing failure + + # Handle division by zero if no GPUs are found or parsed successfully + if num_gpus > 0: + avg_gpu_utilization = total_gpu_utilization / num_gpus + avg_memory_used = total_memory_used / num_gpus + else: + avg_gpu_utilization = 0.0 + avg_memory_used = 0.0 + + + # print(f"GPU Utilization: {avg_gpu_utilization}%") + # print(f"Memory Used: {avg_memory_used} MiB") + + return avg_gpu_utilization, avg_memory_used + except Exception as e: + print(f"Failed to get GPU utilization: {e}") + return 0, 0 + + + +def get_cpu_utilization(): + # Gets the system-wide CPU utilization + return psutil.cpu_percent(interval=1) + +def get_memory_usage(): + # Gets the system-wide memory usage + memory_info = psutil.virtual_memory() + return memory_info.used / (1024 * 1024) # Convert to MB + +def monitor_resources(start_time, interval=5, duration=60): + """ + Monitors and collects resource usage metrics over a specified duration and interval. + + :param start_time: The time when the monitoring started, to calculate total duration. + :param interval: Time in seconds between each metric collection. + :param duration: Total duration to monitor resources. + :return: A dictionary containing the collected metrics. + """ + metrics = { + 'max_gpu_utilization': 0, + 'max_memory_usage_gpu': 0, + 'cpu_utilization': [], + 'peak_memory_usage_system': 0, + } + + while True: + current_time = time.time() + if current_time - start_time > duration: + break + + gpu_utilization, memory_usage_gpu = get_gpu_utilization() + cpu_utilization = get_cpu_utilization() + memory_usage_system = get_memory_usage() + + metrics['max_gpu_utilization'] = max(metrics['max_gpu_utilization'], gpu_utilization) + metrics['max_memory_usage_gpu'] = max(metrics['max_memory_usage_gpu'], memory_usage_gpu) + metrics['cpu_utilization'].append(cpu_utilization) + metrics['peak_memory_usage_system'] = max(metrics['peak_memory_usage_system'], memory_usage_system) + + time.sleep(interval) + + return metrics + +def get_configs(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-config-file", + help="The path to a JSON file with the configs. 
If omitted, a sample model and configs will be used instead.", + type=str, + default="", + ) + parser.add_argument( + "--publish-peft-with-id", + help="The Hugging Face model ID to upload the trained model with", + type=str, + default="" + ) + + args = parser.parse_args() + publish_peft_with_id = args.publish_peft_with_id + if len(publish_peft_with_id) == 0: + print( + "Please pass a --publish-peft-with-id if you want to upload the trained model" + ) + else: + print(f"The trained model will be uploaded with id: {publish_peft_with_id}") + + # Load configs from JSON file (if specified) + if len(args.config_file) > 0: + if not os.path.isfile(args.config_file): + raise FileNotFoundError(f"Config file {args.config_file} not found.") + try: + with open(args.config_file) as f: + return json.load(f) + except json.JSONDecodeError as e: + print("JSON format error:") + print(e) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 1, + "memory_per_gpu": 8192, + "zero_copy_memory_per_node": 12000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 1, + "offload": False, + "offload_reserve_space_size": 8 * 1024, # 8GB + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "enable_peft": True, + "peft_activation_reserve_space_size": 1024, # 1GB + "peft_weight_reserve_space_size": 1024, # 1GB + "profiling": False, + "inference_debugging": True, + "fusion": True, + } + model_configs = { + # required parameters + "base_model": "JackFram/llama-160m", + "peft_model_ids": [ + "goliaro/llama-160m-lora-full", + ], + # optional parameters + "cache_path": "~/.cache/flexflow", + "refresh_cache": False, + "full_precision": False, + "prompt": "", + "finetuning_dataset": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../prompt/peft.json" + # peft.json is a sample dataset for finetuning, should contain a list of strings + ), + "output_file": "" + } + # Merge dictionaries + ff_init_configs.update(model_configs) + ff_init_configs["publish_peft_with_id"] = publish_peft_with_id + return ff_init_configs + + +def main(): + start_time = time.time() + configs_dict = get_configs() + configs = SimpleNamespace(**configs_dict) + + # Initialize the FlexFlow runtime. 
ff.init() takes a dictionary or the path to a JSON file with the configs + ff.init(configs_dict) + + # Create the FlexFlow LLM + ff_data_type = ( + ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + ) + llm = ff.LLM( + configs.base_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file + ) + for peft_model_id in configs.peft_model_ids: + llm.add_peft(peft_model_id) + + # Compile the LLM for inference and load the weights into memory + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + llm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64, + ) + + resource_metrics = monitor_resources(start_time, interval=5, duration=360) + + llm.start_server() + + print(f"LLM model class is: {llm.model_class}") + + requests = [] + # Serving + if len(configs.prompt) > 0: + prompts = [s for s in json.load(open(configs.prompt))] + inference_requests = [ + ff.Request( + ff.RequestType.REQ_INFERENCE, prompt=prompt, max_sequence_length=128 + ) + for prompt in prompts + ] + requests += inference_requests + # Finetuning + if len(configs.finetuning_dataset) > 0: + for peft_model_id in configs.peft_model_ids: + finetuning_request = ff.Request( + ff.RequestType.REQ_FINETUNING, + max_sequence_length=128, + peft_model_id=llm.get_ff_peft_id(peft_model_id), + dataset_filepath=configs.finetuning_dataset, + ) + requests.append(finetuning_request) + + # use the (finetuned) llm to generate some responses + llm.generate(requests) + + # After finishing the main workload, print the collected metrics. + avg_cpu_utilization = sum(resource_metrics['cpu_utilization']) / len(resource_metrics['cpu_utilization']) + print(f"Max GPU Utilization: {resource_metrics['max_gpu_utilization']}%") + print(f"Max GPU Memory Usage: {resource_metrics['max_memory_usage_gpu']} MiB") + print(f"Average CPU Utilization: {avg_cpu_utilization}%") + print(f"Peak System Memory Usage: {resource_metrics['peak_memory_usage_system']} MiB") + + + llm.stop_server() + + # upload the model back to huggingface after finetuning + # the model format would be converted from flexflow format back to huggingface format + if len(configs.publish_peft_with_id) > 0: + print( + f"Done training! Uploading the model to HF hub with id: {configs.publish_peft_with_id}..." 
+        )
+        llm.upload_peft_model(configs.publish_peft_with_id, private=True)
+
+
+if __name__ == "__main__":
+    print("flexflow PEFT example")
+
+    main()
\ No newline at end of file
diff --git a/inference/utils/download_peft_model.py b/inference/utils/download_peft_model.py
index 38dd577574..82239cf221 100644
--- a/inference/utils/download_peft_model.py
+++ b/inference/utils/download_peft_model.py
@@ -6,7 +6,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--base_model_name", type=str, help="Name of the model to download"
+        "--base_model_name", type=str, required=True, help="Name of the model to download"
     )
     parser.add_argument(
         "peft_model_ids",
diff --git a/inference/utils/download_upload_peft.py b/inference/utils/download_upload_peft.py
new file mode 100644
index 0000000000..27dd1e5607
--- /dev/null
+++ b/inference/utils/download_upload_peft.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+import argparse
+from huggingface_hub import HfApi, HfFolder
+import flexflow.serve as ff
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Download a PEFT model with FlexFlow, process it, and upload it to the Hugging Face Hub.")
+    parser.add_argument("peft_model_id", type=str, help="Original Hugging Face PEFT model ID to download and process (e.g., 'username/peft-model').")
+    parser.add_argument("--new-model-id", type=str, required=True, help="New Hugging Face Hub model ID for upload (e.g., 'your_username/new-peft-model-name').")
+    parser.add_argument("--cache-folder", type=str, default="./peft_model_cache", help="Folder to use to store and process the PEFT model(s) assets in FlexFlow format.")
+    parser.add_argument("--private", action="store_true", help="Whether to upload the processed PEFT model as a private model on Hugging Face Hub.")
+    parser.add_argument("--refresh-cache", action="store_true", help="Use this flag to force the refresh of the PEFT model(s) weights/cache.")
+    parser.add_argument("--full-precision", action="store_true", help="Download the full precision version of the weights for the PEFT model.")
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    data_type = ff.DataType.DT_FLOAT if args.full_precision else ff.DataType.DT_HALF
+    print(f"Downloading and processing peft model: {args.peft_model_id}")
+    peft = ff.PEFT(
+        args.peft_model_id,
+        data_type=data_type,
+        cache_path=args.cache_folder,
+        refresh_cache=args.refresh_cache,
+    )
+    peft.download_hf_weights_if_needed()
+    peft.download_hf_config()
+
+    print(f"Uploading processed model to Hugging Face Hub: {args.new_model_id}")
+    peft.upload_hf_model(args.new_model_id, args.cache_folder, private=args.private)
+    print("Upload completed successfully.")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/inference/utils/upload_hf_model.py b/inference/utils/upload_hf_model.py
new file mode 100644
index 0000000000..59e4573461
--- /dev/null
+++ b/inference/utils/upload_hf_model.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+
+# This script tests downloading a model from Hugging Face and uploading it back.
+# After the model is downloaded, it is converted into the FlexFlow weight format;
+# before uploading it back to Hugging Face, it is converted back to the Hugging
+# Face format by calling llm.upload_hf_model().
+
+import argparse, os
+import flexflow.serve as ff
+import warnings
+
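+# Example usage (the model IDs and paths below are placeholders, not defaults of
+# this script):
+#   python upload_hf_model.py JackFram/llama-160m \
+#       --new-model-id <your-hf-username>/llama-160m-flexflow \
+#       --cache-folder ~/.cache/flexflow --full-precision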
+warnings.filterwarnings("ignore") + +def parse_args(): + parser = argparse.ArgumentParser( + description="Download a model with FlexFlow, process it, and upload it to the Hugging Face Hub." + ) + parser.add_argument( + "model_name", + type=str, + help="Original Hugging Face model ID to download and process (e.g., 'facebook/opt-125m')." + ) + parser.add_argument( + "--new-model-id", + type=str, + required=True, + help="New Hugging Face Hub model ID for upload (e.g., 'your_username/new-model-name')." + ) + parser.add_argument( + "--cache-folder", + type=str, + help="Folder to use to store the model(s) assets in FlexFlow format", + default=os.environ.get("FF_CACHE_PATH", ""), + ) + parser.add_argument("--private", action="store_true", help="Whether to upload the processed model as a private model on Hugging Face Hub.") + parser.add_argument("--full-precision", action="store_true", help="Download the full precision version of the weights.") + return parser.parse_args() + + +def main(): + args = parse_args() + data_type = ff.DataType.DT_FLOAT if args.full_precision else ff.DataType.DT_HALF + print(f"Downloading and processing model: {args.model_name}") + llm = ff.LLM( + model_name=args.model_name, + data_type=data_type, + cache_path=args.cache_folder, + refresh_cache=False, + ) + print(f"Uploading processed model to Hugging Face Hub: {args.new_model_id}") + llm.upload_hf_model(args.new_model_id, private=args.private) + print("Upload completed successfully.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 0f41b5235c..3e12568eb9 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -1630,6 +1630,12 @@ def start_server(self, model): def stop_server(self): return ffc().flexflow_request_manager_terminate_background_server(self.handle) + + def save_peft_weights(self, model, peft_model_id, destination_folder): + c_destination_folder = get_c_name(destination_folder) + return ffc().flexflow_request_manager_save_peft_weights( + self.handle, model.handle, peft_model_id.handle, c_destination_folder + ) # ----------------------------------------------------------------------- diff --git a/python/flexflow/serve/__init__.py b/python/flexflow/serve/__init__.py index fd29080a6a..6c0296768a 100644 --- a/python/flexflow/serve/__init__.py +++ b/python/flexflow/serve/__init__.py @@ -15,16 +15,7 @@ from typing import Optional from ..type import * from flexflow.core import * -from .serve import ( - LLM, - SSM, - GenerationConfig, - GenerationResult, - LoraLinearConfig, - PEFTModelID, - Request, - RequestType, -) +from .serve import LLM, SSM def __check_positive_int(configs_dict: dict, key: str): diff --git a/python/flexflow/serve/models/base.py b/python/flexflow/serve/models/base.py index 17bb894250..b38faedc3e 100644 --- a/python/flexflow/serve/models/base.py +++ b/python/flexflow/serve/models/base.py @@ -32,8 +32,11 @@ def __init__( def build_model(self): assert False, "Not implemented yet" - def convert_hf_weight_name(name): + def convert_weight_name_hf2ff(name): assert False, "Not implemented yet" def convert_hf_model(model, dst_folder): assert False, "Not implemented yet" + + def load_weights_into_hf_model(model, src_folder): + assert False, "Not implemented yet" diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py index 0e8fbcbd7d..660b80709c 100644 --- a/python/flexflow/serve/models/falcon.py +++ 
b/python/flexflow/serve/models/falcon.py @@ -15,6 +15,7 @@ from flexflow.core import * from .base import FlexFlowModel import random, torch +import re class FalconConfig: @@ -244,13 +245,18 @@ def build_model(self, max_tokens_per_batch): self.ffmodel = ffmodel - # TODO: finish this - def convert_hf_weight_name(name): - return (name.replace("transformer.h.", "layers.") + def convert_weight_name_hf2ff(name): + return ( + name.replace("transformer.h.", "layers.") .replace("transformer.", "") .replace("self_attention.dense", "self_attention.o_proj") ) + def convert_weight_name_ff2hf(name): + return "transformer." + name.replace( + "self_attention.o_proj", "self_attention.dense" + ).replace("layers.", "h.") + def convert_hf_model(model, dst_folder): os.makedirs(dst_folder, exist_ok=True) n_head = ( @@ -259,12 +265,18 @@ def convert_hf_model(model, dst_folder): else model.config.num_attention_heads ) for name, params in model.named_parameters(): - name = FlexFlowFalcon.convert_hf_weight_name(name) + name = FlexFlowFalcon.convert_weight_name_hf2ff(name) # Split Q,K,V attention weights if "self_attention.query_key_value" in name: - name_q = name.replace("self_attention.query_key_value", "self_attention.q_proj") - name_k = name.replace("self_attention.query_key_value", "self_attention.k_proj") - name_v = name.replace("self_attention.query_key_value", "self_attention.v_proj") + name_q = name.replace( + "self_attention.query_key_value", "self_attention.q_proj" + ) + name_k = name.replace( + "self_attention.query_key_value", "self_attention.k_proj" + ) + name_v = name.replace( + "self_attention.query_key_value", "self_attention.v_proj" + ) q, k, v = torch.split( params, [ @@ -283,3 +295,151 @@ def convert_hf_model(model, dst_folder): model.lm_head.weight.detach().cpu().numpy().tofile( os.path.join(dst_folder, "lm_head.weight") ) + + def load_weights_into_hf_model(model, src_folder): + """ + Load weights from a specified folder and apply them to a Hugging Face model. + + Parameters: + - model: The instance of the Hugging Face model to load the weights into. + - src_folder: The path to the folder containing the weight files. + - config: The configuration object for the model. 
+ """ + + print(f"loading weights from {model} into {src_folder}") + + hidden_size = model.config.hidden_size + n_head = ( + model.config.n_head + if "n_head" in model.config.__dict__ + else model.config.num_attention_heads + ) + + print("Model hidden size:", hidden_size) + print("Model num_attention_heads:", n_head) + + # num_attention_heads = n_head + # hidden_size_per_head = hidden_size // n_head + # intermediate_size = hidden_size * 4 + + qkv_weights = {} + + for file_name in os.listdir(src_folder): + weight_path = os.path.join(src_folder, file_name) + print("\nProcessing weight file:", weight_path) + if weight_path.endswith("rev_sha.txt"): + print("skipping rev_sha.txt") + continue + else: + original_name = FlexFlowFalcon.convert_weight_name_ff2hf(file_name) + print(f"Converted weight name from {file_name} to {original_name}") + + if not os.path.exists(weight_path): + raise FileNotFoundError(f"No weight file found for {file_name}") + + weight_data = np.fromfile(weight_path, dtype=np.float16).astype(np.float32) + print( + f"Data type after conversion: {weight_data.dtype}, Size: {weight_data.size}" + ) + + # for q,k,v weights, store in dict + if ( + ("q_proj" in original_name) + or ("k_proj" in original_name) + or ("v_proj" in original_name) + ): + + layer_num_match = re.search(r"transformer.h.(\d+)", original_name) + layer_num = int(layer_num_match.group(1)) if layer_num_match else None + qkv_type = file_name.split(".")[-2] + print(f"qkv type for this weight is {qkv_type}") + + if layer_num is not None: + qkv_key = ( + f"transformer.h.{layer_num}.self_attention.query_key_value" + ) + if qkv_key not in qkv_weights: + qkv_weights[qkv_key] = { + "q_proj": None, + "k_proj": None, + "v_proj": None, + } + + qkv_weights[qkv_key][qkv_type] = weight_data + continue + + # Handle non-QKV weights normally + param = model.state_dict()[original_name] + expected_numel = param.numel() + print(f"expected param shape is {expected_numel}") + if param is None: + # raise ValueError(f"Warning: {original_name} not found!") + print(f"Warning: {original_name} not found!") + continue + + if weight_data.size != param.numel(): + # print(f"shape mismatch for {original_name}, model expects {param.numel()} elements, got {weight_data.size}") + expected_shape = param.shape + if weight_data.size % param.numel() == 0: + factor = weight_data.size // np.prod(expected_shape) + new_shape = (factor,) + tuple(expected_shape) + weight_data_reshaped = weight_data.reshape(new_shape)[0] + weight_tensor = torch.from_numpy(weight_data_reshaped) + else: + raise ValueError( + f"Shape mismatch and cannot convert for {original_name}" + ) + else: + weight_tensor = torch.from_numpy(weight_data).reshape(param.shape) + + print(f"shape of the weight tensor is: {weight_tensor.shape}") + with torch.no_grad(): + model.state_dict()[original_name].copy_(weight_tensor) + print(f"Assigned weight {original_name} successfully!\n") + + # Assign combined QKV weights + for qkv_name, weights_dict in qkv_weights.items(): + print("\n========= Processing combined QKV weights ==========") + print( + f"qkv name is {qkv_name}, hidden size is {hidden_size}, number of attention heads is {n_head}" + ) + print( + f"the weights dimensions are: {weights_dict['q_proj'].shape}, {weights_dict['k_proj'].shape}, {weights_dict['v_proj'].shape}" + ) + + q_proj_weight = weights_dict["q_proj"] + k_proj_weight = weights_dict["k_proj"] + v_proj_weight = weights_dict["v_proj"] + + print("Original QKV weights dimensions:") + print("Q:", q_proj_weight.shape) + print("K:", 
k_proj_weight.shape) + print("V:", v_proj_weight.shape) + + # Reshape the weights to match the expected shape + q_proj_weight_reshaped = q_proj_weight.reshape(-1, hidden_size) + k_proj_weight_reshaped = k_proj_weight.reshape(-1, hidden_size // n_head) + v_proj_weight_reshaped = v_proj_weight.reshape(-1, hidden_size // n_head) + # q_proj_weight_reshaped = q_proj_weight.reshape(k_proj_weight_reshaped.shape[0], -1) + + print("Reshaped QKV weights dimensions:") + print("Q:", q_proj_weight_reshaped.shape) + print("K:", k_proj_weight_reshaped.shape) + print("V:", v_proj_weight_reshaped.shape) + + combined_qkv = np.concatenate( + [ + q_proj_weight_reshaped, + k_proj_weight_reshaped, + v_proj_weight_reshaped, + ], + axis=1, + ) + qkv_weight_name = qkv_name + ".weight" + param_shape = model.state_dict()[qkv_weight_name].shape + print( + f"param shape expected to be {param_shape}, qkv weights combined with weights size {combined_qkv.shape}" + ) + + model.state_dict()[qkv_weight_name].copy_(torch.from_numpy(combined_qkv)) + print(f"Assigned combined QKV weights to {qkv_weight_name}.") diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 96f0258572..e6a08d1563 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -15,6 +15,10 @@ from flexflow.core import * from .base import FlexFlowModel import random +import re +import os +import numpy as np +import torch class LLAMAConfig: @@ -251,11 +255,64 @@ def build_model(self, max_tokens_per_batch): self.ffmodel = ffmodel - def convert_hf_weight_name(name): + def convert_weight_name_hf2ff(name): return name.replace("model.", "") + def convert_weight_name_ff2hf(name): + if name == "lm_head.weight": + return name + else: + return "model." + name + def convert_hf_model(model, dst_folder): os.makedirs(dst_folder, exist_ok=True) for name, params in model.named_parameters(): - name = FlexFlowLLAMA.convert_hf_weight_name(name) + name = FlexFlowLLAMA.convert_weight_name_hf2ff(name) params.detach().cpu().numpy().tofile(f"{dst_folder}/{name}") + + def load_weights_into_hf_model(model, src_folder): + """ + Load weights from a specified folder and apply them to a Hugging Face model. + + Parameters: + - model: The instance of the Hugging Face model to load weights into. + - src_folder: The path to the folder containing the weight files. + """ + for file_name in os.listdir(src_folder): + weight_path = os.path.join(src_folder, file_name) + if weight_path.endswith("rev_sha.txt"): + print("skipping rev_sha.txt") + continue + original_name = FlexFlowLLAMA.convert_weight_name_ff2hf(file_name) + print(f"Converting weight name: {file_name} to {original_name}") + + if not os.path.exists(weight_path): + raise FileNotFoundError(f"No weight file found for {file_name}") + + ff_dtype = np.float32 if "full-precision" in weight_path else np.float16 + weight_data = np.fromfile( + weight_path, dtype=ff_dtype + ) # .astype(np.float32) + if original_name not in model.state_dict(): + raise KeyError(f"Parameter {original_name} not found in model.") + + param = model.state_dict()[original_name] + expected_numel = param.numel() + if weight_data.size != expected_numel: + print( + f"Adjusting shape for {original_name} from {weight_data.size} to {expected_numel}." 
+ ) + if weight_data.size % expected_numel == 0: + factor = weight_data.size // expected_numel + new_shape = (factor,) + tuple(param.shape) + weight_data_reshaped = weight_data.reshape(new_shape)[0] + weight_tensor = torch.from_numpy(weight_data_reshaped) + else: + raise ValueError( + f"Cannot adjust shape for {original_name} due to incompatible size." + ) + else: + weight_tensor = torch.from_numpy(weight_data).reshape(param.shape) + + with torch.no_grad(): + param.copy_(weight_tensor) diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py index b350ae106d..a17ac42d0d 100644 --- a/python/flexflow/serve/models/mpt.py +++ b/python/flexflow/serve/models/mpt.py @@ -14,7 +14,8 @@ from flexflow.core import * from .base import FlexFlowModel -import random, torch, shutil +import random, torch, shutil, os, re +import numpy as np class MPTConfig: @@ -254,18 +255,22 @@ def build_model(self, max_tokens_per_batch): self.ffmodel = ffmodel - # TODO: finish this - def convert_hf_weight_name(name): + def convert_weight_name_hf2ff(name): return ( name.replace("transformer.blocks.", "layers.") .replace("transformer.", "") .replace("attn.out_proj", "attn.o_proj") ) + def convert_weight_name_ff2hf(name): + return "transformer." + name.replace("attn.o_proj", "attn.out_proj").replace( + "layers.", "blocks." + ) + def convert_hf_model(model, dst_folder): os.makedirs(dst_folder, exist_ok=True) for name, params in model.named_parameters(): - name = FlexFlowMPT.convert_hf_weight_name(name) + name = FlexFlowMPT.convert_weight_name_hf2ff(name) if "Wqkv" in name: name_q = name.replace("attn.Wqkv", "attn.q_proj") name_k = name.replace("attn.Wqkv", "attn.k_proj") @@ -289,3 +294,84 @@ def convert_hf_model(model, dst_folder): os.path.join(dst_folder, "wte.weight"), os.path.join(dst_folder, "lm_head.weight"), ) + + def load_weights_into_hf_model(model, src_folder): + """ + Load weights from a specified folder and apply them to a Hugging Face MPT model. + + Parameters: + - model: The instance of the Hugging Face model to load the weights into. + - src_folder: The path to the folder containing the weight files. 
+ """ + + d_model = model.config.d_model + print("dimension of the model is: ", d_model) + + qkv_weights = {} + + for file_name in os.listdir(src_folder): + weight_path = os.path.join(src_folder, file_name) + if weight_path.endswith("rev_sha.txt"): + print("skipping rev_sha.txt") + continue + elif "lm_head" in weight_path: + print("skipping lm_head.weight") + continue + else: + original_name = FlexFlowMPT.convert_weight_name_ff2hf(file_name) + print("\nconverting weights name of: ", file_name, "to ", original_name) + + if not os.path.exists(weight_path): + raise FileNotFoundError(f"No weight file found for {file_name}") + + weight_data = np.fromfile(weight_path, dtype=np.float32) + print( + f"Data type after conversion: {weight_data.dtype}, Size: {weight_data.size}" + ) + + # Special handling for combined QKV weights + if ( + ("q_proj" in file_name) + or ("k_proj" in file_name) + or ("v_proj" in file_name) + ): + layer_num_match = re.search(r"layers\.(\d+)", original_name) + layer_num = int(layer_num_match.group(1)) if layer_num_match else None + qkv_type = original_name.split("_")[-2] + + if layer_num is not None: + qkv_key = f"layers.{layer_num}.attn_Wqkv" + # initialize qkv layer in dict + if qkv_key not in qkv_weights: + qkv_weights[qkv_key] = {"wq": None, "wk": None, "wv": None} + print(f"Initialized QKV layer {layer_num}") + # assign weights into dict + qkv_weights[qkv_key][qkv_type] = weight_data + + continue + + # for weights that are not q,k,v, get the param names and assign weights accordingly + param = model.state_dict().get(original_name, None) + if weight_data.size != param.numel(): + raise ValueError( + f"Shape mismatch for {original_name}, model expects {param.numel()} elements, got {weight_data.size}" + ) + + weight_tensor = torch.from_numpy(weight_data).reshape(param.shape) + with torch.no_grad(): + model.state_dict()[original_name].copy_(weight_tensor) + + for qkv_key, weights_dict in qkv_weights.items(): + wq, wk, wv = weights_dict["wq"], weights_dict["wk"], weights_dict["wv"] + if None in (wq, wk, wv): + raise ValueError(f"Missing weights for {qkv_key}") + + combined_qkv = np.concatenate([wq, wk, wv], axis=0) + qkv_name = qkv_key.replace("layers.", "transformer.blocks.") + ".weight" + + param_shape = model.state_dict()[qkv_name].shape + combined_qkv_reshaped = combined_qkv.reshape(param_shape) + + model.state_dict()[qkv_name].copy_(torch.from_numpy(combined_qkv_reshaped)) + + print(f"Assigned combined QKV weights to {qkv_key}.") diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py index 02668abf59..5aaf34ce03 100644 --- a/python/flexflow/serve/models/opt.py +++ b/python/flexflow/serve/models/opt.py @@ -15,6 +15,8 @@ from flexflow.core import * from .base import FlexFlowModel import random, shutil +import re +import torch class OPTConfig: @@ -284,7 +286,7 @@ def build_model(self, max_tokens_per_batch): self.ffmodel = ffmodel - def convert_hf_weight_name(name): + def convert_weight_name_hf2ff(name): return ( name.replace("decoder.", "") .replace("model.", "") @@ -295,13 +297,79 @@ def convert_hf_weight_name(name): ) # important to use the leading "_" to avoid matching the last LayerNorm ) + def convert_weight_name_ff2hf(name): + return ( + ("model.decoder." 
+ name) + .replace(".add_bias_residual_layer_norm", ".final_layer_norm") + .replace("add_bias_residual_layer_norm.attn_bias", "self_attn.o_proj.bias") + .replace("self_attn.o_proj", "self_attn.out_proj") + ) + def convert_hf_model(model, dst_folder): os.makedirs(dst_folder, exist_ok=True) for name, params in model.named_parameters(): - name = FlexFlowOPT.convert_hf_weight_name(name) + name = FlexFlowOPT.convert_weight_name_hf2ff(name) params.detach().cpu().numpy().tofile(f"{dst_folder}/{name}") # copy embedding weights shutil.copy( os.path.join(dst_folder, "embed_tokens.weight"), os.path.join(dst_folder, "lm_head.weight"), ) + + def load_weights_into_hf_model(model, src_folder): + """ + Load weights from a specified folder and apply them to a Hugging Face model. + + This function iterates through the weight files in the specified folder, + converts the FlexFlow weight names to Hugging Face format, and loads the + weights into the Hugging Face model. It handles special cases like shape + mismatches by adjusting the weights accordingly. + + Parameters: + - model: The instance of the Hugging Face model to load the weights into. + - src_folder: The path to the folder containing the weight files. + """ + + for file_name in os.listdir(src_folder): + weight_path = os.path.join(src_folder, file_name) + print("Converting weight name:", weight_path) + + if weight_path.endswith("rev_sha.txt"): + print("Skipping rev_sha.txt") + continue + + original_name = FlexFlowOPT.convert_weight_name_ff2hf(file_name) + print(f"Converting weight name: {file_name} to {original_name}") + if not os.path.exists(weight_path): + raise FileNotFoundError(f"No weight file found for {file_name}") + + ff_dtype = np.float32 if "full-precision" in weight_path else np.float16 + weight_data = np.fromfile( + weight_path, dtype=ff_dtype + ) # .astype(np.float32) + if original_name not in model.state_dict(): + raise KeyError(f"Parameter {original_name} not found in model.") + param = model.state_dict()[original_name] + + # Calculate the reshape size automatically based on expected parameter size + expected_numel = param.numel() + if weight_data.size != expected_numel: + print( + f"Adjusting shape for {original_name} from {weight_data.size} to {expected_numel}" + ) + # Check if weight_data can be evenly divided by expected_numel + if weight_data.size % expected_numel == 0: + # Determine the reshape size + factor = weight_data.size // expected_numel + new_shape = (factor,) + tuple(param.shape) + weight_data_reshaped = weight_data.reshape(new_shape) + weight_tensor = torch.from_numpy(weight_data_reshaped[0]) + else: + raise ValueError( + f"Cannot adjust shape for {original_name} due to incompatible size." + ) + else: + weight_tensor = torch.from_numpy(weight_data).reshape(param.shape) + + with torch.no_grad(): + param.copy_(weight_tensor) diff --git a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py index 2d4471201f..d52e03aecf 100644 --- a/python/flexflow/serve/models/starcoder.py +++ b/python/flexflow/serve/models/starcoder.py @@ -14,7 +14,8 @@ from flexflow.core import * from .base import FlexFlowModel -import random, torch +import random, torch, re +import numpy as np class STARCODERConfig: @@ -269,3 +270,107 @@ def convert_hf_model(model, dst_folder): model.lm_head.weight.detach().cpu().numpy().tofile( os.path.join(dst_folder, "lm_head.weight") ) + + def convert_weight_name_ff2hf(name): + return "transformer." 
+ name.replace("layers.", "h.") + + def load_weights_into_hf_model(model, src_folder): + """ + Load weights from a specified folder and apply them to a Hugging Face model. + + Parameters: + - model: The instance of the Hugging Face model to load the weights into. + - src_folder: The path to the folder containing the weight files. + """ + + hidden_size = model.config.hidden_size + n_head = ( + model.config.n_head + if "n_head" in model.config.__dict__ + else model.config.num_attention_heads + ) + + print("Model hidden size:", hidden_size) + print("Model num_attention_heads:", n_head) + + qkv_weights = {} + + for file_name in os.listdir(src_folder): + weight_path = os.path.join(src_folder, file_name) + print("\nProcessing weight file:", weight_path) + if weight_path.endswith("rev_sha.txt"): + print("skipping rev_sha.txt") + continue + else: + original_name = FlexFlowSTARCODER.convert_weight_name_ff2hf(file_name) + print(f"Converted weight name: {file_name} to {original_name}") + + if not os.path.exists(weight_path): + raise FileNotFoundError(f"No weight file found for {file_name}") + + weight_data = np.fromfile(weight_path, dtype=np.float32) + print( + f"Data type after conversion: {weight_data.dtype}, Size: {weight_data.size}" + ) + + # Special handling for combined QKV weights + if ( + ("q_proj" in original_name) + or ("k_proj" in original_name) + or ("v_proj" in original_name) + ): + weight_bias = ".weight" if ".weight" in original_name else ".bias" + layer_num_match = re.search(r"layers.(\d+)", file_name) + layer_num = int(layer_num_match.group(1)) if layer_num_match else None + print(f"layer_num is {layer_num}") + qkv_type = file_name.split("_")[-2] + qkv_name = f"transformer.h.{layer_num}.attn.c_attn" + weight_bias + + if layer_num is not None: + # initialize qkv layer in dict + if qkv_name not in qkv_weights: + qkv_weights[qkv_name] = { + "attn.q": None, + "attn.k": None, + "attn.v": None, + } + print(f"Initialized QKV layer {layer_num}") + # assign weights into dict + qkv_weights[qkv_name][qkv_type] = weight_data + print( + f"attached qkv weight {qkv_name} for type {qkv_type}, weight data dimension is {weight_data.shape}" + ) + + continue + + # Handling for other parameters + # for weights that are not q,k,v, get the param names and assign weights accordingly + param = model.state_dict().get(original_name, None) + print(f"Param name: {original_name}") + if weight_data.size != param.numel(): + raise ValueError( + f"Shape mismatch for {original_name}, model expects {param.numel()} elements, got {weight_data.size}" + ) + + weight_tensor = torch.from_numpy(weight_data).reshape(param.shape) + print(f"shape of the weight tensor is: {weight_tensor.shape}") + with torch.no_grad(): + model.state_dict()[original_name].copy_(weight_tensor) + print(f"Assigned weight {original_name} successfully!\n") + + for qkv_name, weights_dict in qkv_weights.items(): + print(f"qkv name is {qkv_name}, with weight {weights_dict}") + combined_qkv = np.concatenate( + [ + qkv_weights[qkv_name]["attn.q"], + qkv_weights[qkv_name]["attn.k"], + qkv_weights[qkv_name]["attn.v"], + ], + axis=0, + ) + param_shape = model.state_dict()[qkv_name].shape + combined_qkv_reshaped = combined_qkv.reshape(param_shape) + print(f"reshaped qkv weights shape is: {combined_qkv_reshaped.shape}") + + model.state_dict()[qkv_name].copy_(torch.from_numpy(combined_qkv_reshaped)) + print(f"Assigned combined QKV weights to {qkv_name}.") diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index 319505794a..5f7320bc1c 
100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -29,9 +29,10 @@ from flexflow.core import * from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer from peft import PeftModel, PeftConfig, LoraConfig -from huggingface_hub import HfApi -import torch, shutil, hashlib, json, gc +from huggingface_hub import HfApi, HfFolder, Repository +import torch, shutil, hashlib, json, gc, os from typing import Union, List +import tempfile class _SupportedModels: @@ -97,6 +98,12 @@ def __init__( self.data_type = data_type assert self.data_type == DataType.DT_HALF or self.data_type == DataType.DT_FLOAT self.cache_path = cache_path if len(cache_path) > 0 else "~/.cache/flexflow" + self.weights_path = self.__get_weights_path(self.model_name) + self.tokenizer_path = os.path.join( + os.path.expanduser(self.cache_path), + "tokenizers", + self.model_name.lower(), + ) self.refresh_cache = refresh_cache self.output_file = output_file self.rm = None @@ -213,6 +220,18 @@ def __get_revision_hashes(self, model_name: str, folder: str): latest_revision = hf_api.model_info(self.model_name).sha return ff_revision, ff_revision_file, latest_revision + def __get_weights_path(self, model_name): + return os.path.join( + os.path.expanduser(self.cache_path), + "weights", + model_name.lower(), + ( + "full-precision" + if self.data_type == DataType.DT_FLOAT + else "half-precision" + ), + ) + def download_hf_weights_if_needed(self): """Check in the folder specified by the cache_path whether the LLM's model weights are available and up to date. If not, or if the refresh_cache parameter is set to True, download new weights. @@ -220,20 +239,8 @@ def download_hf_weights_if_needed(self): If any PEFT adapter is registered, perform the same operation for PEFT. """ - def get_weights_path(model_name): - return os.path.join( - os.path.expanduser(self.cache_path), - "weights", - model_name.lower(), - ( - "full-precision" - if self.data_type == DataType.DT_FLOAT - else "half-precision" - ), - ) - def refresh_cache_if_needed(model_name): - weights_path = get_weights_path(model_name) + weights_path = self.__get_weights_path(model_name) if self.refresh_cache: print( f"Refreshing weights in cache for model {model_name} at path {weights_path} ..." 
@@ -280,7 +287,7 @@ def convert_peft_model(hf_peft_model, peft_type, weights_path): name = name.replace("base_model.model.model.", "").replace( ".default", "" ) - name = self.model_class.convert_hf_weight_name(name) + name = self.model_class.convert_weight_name_hf2ff(name) params.detach().cpu().numpy().tofile(f"{weights_path}/{name}") def download_peft_weights(): @@ -289,7 +296,7 @@ def download_peft_weights(): peft_type = peft_dict["peft_type"] peft_model_id = ff_peft_config.peft_model_id - weights_path = get_weights_path(peft_model_id) + weights_path = self.__get_weights_path(peft_model_id) refresh_cache_if_needed(peft_model_id) ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes( peft_model_id, weights_path @@ -315,7 +322,6 @@ def download_peft_weights(): gc.collect() torch.cuda.empty_cache() - self.weights_path = get_weights_path(self.model_name) download_llm_weights() download_peft_weights() @@ -364,6 +370,112 @@ def download_hf_tokenizer_if_needed(self): # Save new revision hash to file with open(ff_revision_file, "w+") as f: f.write(latest_revision) + else: + print(f"Loading '{self.model_name}' tokenizer from the cache...") + + def upload_hf_model(self, new_model_id: str, private: bool = False): + """ + Uploads the model to the Hugging Face Hub, with reverse conversion of weights. + + :param new_model_id: The new model ID for the Hugging Face Hub. + :param private: Whether to upload the model as a private model. + """ + # Ensure Hugging Face CLI is logged in + if not HfFolder.get_token(): + raise RuntimeError("Hugging Face token not found. Please login using `huggingface-cli login`.") + + print(f"Preparing model for upload to Hugging Face Hub: {new_model_id}") + print("Tokenizer path is: ", self.tokenizer_path) + + # Initialize a new Hugging Face model instance + hf_model = AutoModelForCausalLM.from_config(self.hf_config) + print(f"Model class is: {self.model_class}") + + # Load FlexFlow weights into the Hugging Face model instance + try: + self.model_class.load_weights_into_hf_model(hf_model, self.weights_path) + except Exception as e: + print(f"Error loading weights into model: {e}") + return + + # Save the model with converted weights to a temporary directory + temp_dir = tempfile.mkdtemp() + hf_model.save_pretrained(temp_dir) + + # Copy the tokenizer files to the temporary directory + tokenizer_files = [f for f in os.listdir(self.tokenizer_path)] + for file_name in tokenizer_files: + shutil.copy(os.path.join(self.tokenizer_path, file_name), temp_dir) + + # Delete rev_sha.txt from the temporary directory if it exists + rev_sha_path = os.path.join(temp_dir, "rev_sha.txt") + if os.path.exists(rev_sha_path): + os.remove(rev_sha_path) + + # Upload the model + api = HfApi() + print(f"Uploading processed model to Hugging Face Hub: {new_model_id}") + api.create_repo(repo_id=new_model_id, private=private, exist_ok=True) + api.upload_folder(folder_path=temp_dir, repo_id=new_model_id) + + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + print("Upload process completed.") + + def upload_peft_model(self, new_model_id: str, private: bool = False): + """ + Uploads the peft model to the Hugging Face Hub, with reverse conversion of weights. + + :param new_model_id: The new model ID for the Hugging Face Hub. + :param private: Whether to upload the model as a private model. 
+ """ + print(f"Preparing model for upload to Hugging Face Hub: {new_model_id}") + print("Tokenizer path is: ", self.tokenizer_path) + + # Initialize a new Hugging Face model instance + hf_model = AutoModelForCausalLM.from_config(self.hf_config) + weights_path = self.weights_path + print(f"Model class is: {self.model_class}") + + # Load FlexFlow weights into the Hugging Face model instance + try: + self.model_class.load_weights_into_hf_model(hf_model, weights_path) + except Exception as e: + print(f"Error loading weights into model: {e}") + return + + # Save the model with converted weights to a temporary directory + temp_dir = tempfile.mkdtemp() + hf_model.save_pretrained(temp_dir) + + # Copy the tokenizer files to the temporary directory + tokenizer_files = [f for f in os.listdir(self.tokenizer_path)] + for file_name in tokenizer_files: + shutil.copy(os.path.join(self.tokenizer_path, file_name), temp_dir) + + # Delete rev_sha.txt from the temporary directory if it exists + rev_sha_path = os.path.join(temp_dir, "rev_sha.txt") + if os.path.exists(rev_sha_path): + os.remove(rev_sha_path) + + # Ensure Hugging Face CLI is logged in + if not HfFolder.get_token(): + print( + "Hugging Face token not found. Please login using `huggingface-cli login`." + ) + return + + # Upload the model + api = HfApi() + print(f"Uploading processed model to Hugging Face Hub: {new_model_id}") + api.create_repo(repo_id=new_model_id, private=private, exist_ok=True) + api.upload_folder(folder_path=temp_dir, repo_id=new_model_id) + + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + print("Upload process completed.") def compile( self, diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 3ba6398db1..8b76f9f862 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -2727,6 +2727,27 @@ void flexflow_request_manager_terminate_background_server( handle->terminate_background_server(); } +void flexflow_request_manager_save_peft_weights( + flexflow_request_manager_t handle_, + flexflow_model_t model_handle_, + flexflow_peft_model_id_t peft_model_id_, + char const *destination_folder) { + RequestManager *handle = FFCObjectWrapper::unwrap(handle_); + FFModel *model_handle = FFCObjectWrapper::unwrap(model_handle_); + PEFTModelID *peft_model_id = FFCObjectWrapper::unwrap(peft_model_id_); + assert(peft_model_id != nullptr && "PEFT model ID cannot be nullptr"); + assert(destination_folder != nullptr && + "Cannot convert nullptr char * to std::string"); + std::string const destination_folder_str(destination_folder); + DEBUG_PRINT("[RequestManager] save peft weights %p %p %p %s", + handle, + model_handle, + peft_model_id, + destination_folder); + handle->save_peft_weights( + model_handle, *peft_model_id, destination_folder_str); +} + // ----------------------------------------------------------------------- // InferenceManager // ----------------------------------------------------------------------- diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index d23034bd74..7613d845f1 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -275,6 +275,176 @@ void LoraLinear::init_inference( set_opmeta_from_futuremap_inference(ff, fm, output_tensor); } +struct LoraLinearSaveWeightsInfo { + LoraLinear const *lora; + PEFTModelID model_id; + int rank; + std::string destination_folder; +}; + +void LoraLinear::save_peft_weights( + FFModel const &ff, + PEFTModelID const &model_id, + int rank, + std::string const &destination_folder, + std::vector const &batch_inputs, + std::vector const 
&batch_outputs, + MachineView const *mv) { + assert(check_output_input_weight_same_parallel_is()); + assert(batch_inputs.size() == 2); + assert(batch_outputs.size() == 1); + // Assert that the output and the second input are mapped to the same + // region/part + assert(batch_outputs[0]->region == batch_inputs[1]->region); + assert(batch_outputs[0]->part == batch_inputs[1]->part); + // assert(check_output_input_weight_same_machine_view()); + // output is considered as an input to allow in-place optimization + ParallelTensor output_tensor = batch_outputs[0]; + parallel_is = output_tensor->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + MachineView const *view = mv ? mv : &output_tensor->machine_view; + size_t machine_view_hash = view->hash(); + set_argumentmap_for_inference(ff, argmap, output_tensor); + LoraLinearSaveWeightsInfo info; + info.lora = this; + info.model_id = model_id; + info.rank = rank; + info.destination_folder = destination_folder; + IndexLauncher launcher(LORA_LINEAR_SAVE_WEIGHTS_TASK_ID, + parallel_is, + TaskArgument(&info, sizeof(LoraLinearSaveWeightsInfo)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); +} + +template +void save_peft_to_file(DT const *weight_ptr, + size_t size, + std::string filepath) { + std::ofstream out(filepath, std::ios::binary); + // Check if the file was opened successfully + if (!out || !out.is_open() || !out.good()) { + printf("Could not open file: %s\n", filepath.c_str()); + } + assert(out && out.is_open() && out.good() && + "can't write to lora weight file path"); + std::vector
<DT> host_array(size);
+  copy_tensor_dev_to_host(weight_ptr, host_array.data(), size);
+
+  size_t target_data_size = sizeof(DT) * size;
+  out.write((char *)host_array.data(), target_data_size);
+
+  size_t out_written_size = out.tellp();
+  if (out_written_size != target_data_size) {
+    printf("save weight data error: %lu, %lu, %lu\n",
+           out_written_size,
+           target_data_size,
+           sizeof(DT));
+    assert(false);
+  }
+  out.close();
+}
+
+void LoraLinear::save_peft_weights_task(
+    Task const *task,
+    std::vector<PhysicalRegion> const &regions,
+    Context ctx,
+    Runtime *runtime) {
+  LoraLinearSaveWeightsInfo const *info =
+      static_cast<LoraLinearSaveWeightsInfo const *>(task->args);
+  LoraLinearMeta *m = *((LoraLinearMeta **)task->local_args);
+  LoraLinear const *lora = info->lora;
+
+  // get shard id
+  int shard_id = task->index_point.point_data[0];
+
+  // get dimensions and sizes
+  int rank = info->rank;
+  int num_dims = lora->inputs[0]->num_dims;
+  int in_dim = lora->inputs[0]->dims[0].size / lora->inputs[0]->dims[0].degree;
+  int out_dim = lora->inputs[1]->dims[0].size / lora->inputs[1]->dims[0].degree;
+  int w0_num_elements = rank * in_dim;
+  int w1_num_elements = rank * out_dim;
+
+  // get data type
+  DataType dt = m->input_type[0];
+  assert(dt == m->input_type[1]);
+  assert(dt == m->output_type[0]);
+  assert(dt == lora->inputs[0]->data_type);
+  assert(dt == lora->inputs[1]->data_type);
+  assert(dt == lora->outputs[0]->data_type);
+
+  // get output filepaths
+  assert(info->destination_folder.length() > 0 &&
+         "Destination folder is not set");
+  struct stat st = {0};
+  assert(stat(info->destination_folder.c_str(), &st) == 0 &&
+         (st.st_mode & S_IFDIR) && "Destination folder does not exist");
+  assert(lora->name != nullptr &&
+         "Layer name is not set, cannot determine weights location");
+  std::string lora_layername = std::string(lora->name);
+  std::string searchString = "lora";
+  size_t found = lora_layername.find(searchString);
+  if (found == std::string::npos) {
+    std::cout << "LoraLinear layer name not in the right format (does not "
+                 "contain word 'lora')"
+              << std::endl;
+    assert(false);
+  }
+  std::string lora_layername_substr =
+      lora_layername.substr(0, found + searchString.length());
+  std::string w0_filepath =
+      join_path({info->destination_folder,
+                 lora_layername_substr + "_A.weight" + ".shard_" +
+                     std::to_string(shard_id)});
+  std::string w1_filepath = join_path(
+      {info->destination_folder, lora_layername_substr + "_B.weight"});
+
+  // check handle to peft weights
+  assert(m->model_weights.find(info->model_id) != m->model_weights.end());
+
+  // save weights to file (the A weights are written once per shard)
+  std::cout << "Saving LORA weight "
+            << lora_layername_substr + "_A.weight" + ".shard_" +
+                   std::to_string(shard_id)
+            << ", size: " << w0_num_elements << ", shard: " << shard_id
+            << std::endl;
+  if (dt == DT_FLOAT) {
+    save_peft_to_file((float *)m->model_weights[info->model_id].w0_ptr,
+                      w0_num_elements,
+                      w0_filepath);
+  } else if (dt == DT_HALF) {
+    save_peft_to_file((half *)m->model_weights[info->model_id].w0_ptr,
+                      w0_num_elements,
+                      w0_filepath);
+  } else {
+    assert(false && "Data type not supported");
+  }
+  // the B weights are written only once, by shard 0, without a shard suffix
+  if (shard_id == 0) {
+    std::cout << "Saving LORA weight " << lora_layername_substr + "_B.weight"
+              << ", size: " << w1_num_elements << ", shard: " << shard_id
+              << std::endl;
+    if (dt == DT_FLOAT) {
+      save_peft_to_file((float *)m->model_weights[info->model_id].w1_ptr,
+                        w1_num_elements,
+                        w1_filepath);
+    } else if (dt == DT_HALF) {
+      save_peft_to_file((half *)m->model_weights[info->model_id].w1_ptr,
+                        w1_num_elements,
+                        w1_filepath);
+    } else {
+      assert(false &&
"Data type not supported"); + } + } +} + template void load_peft_from_file(DT *ptr, size_t num_rows, diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index c4ee3e1d0b..ac67ace259 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -17,13 +17,16 @@ #include "flexflow/graph.h" #include "flexflow/model.h" #include "flexflow/ops/fused.h" +#include "flexflow/ops/lora_linear.h" #include "flexflow/ops/noop.h" #include "flexflow/parallel_ops/parallel_op.h" #include "flexflow/request_manager.h" +#include namespace FlexFlow { using namespace Legion; +namespace fs = std::filesystem; LegionRuntime::Logger::Category log_inf_mgr("InferenceManager"); LegionRuntime::Logger::Category log_offload("Offloading"); @@ -378,6 +381,59 @@ void InferenceManager::init_operators_inference(FFModel *model) { } } +void InferenceManager::save_peft_weights( + FFModel *model, + PEFTModelID const &model_id, + std::string const &destination_folder) { + // check that peft model id exists and get rank + assert(model->peft_configs.find(model_id) != model->peft_configs.end() && + "PEFT model id is invalid"); + // get rank + int rank = model->peft_configs[model_id].rank; + assert(rank > 0 && "Rank must be greater than 0"); + // Delete the folder if it exists, create it + try { + if (fs::exists(destination_folder) && + fs::is_directory(destination_folder)) { + fs::remove_all(destination_folder); + } + } catch (fs::filesystem_error const &e) { + std::cout << "Error deleting folder: " << e.what() << std::endl; + } + try { + // Create the folder + fs::create_directory(destination_folder); + } catch (fs::filesystem_error const &e) { + std::cout << "Error creating folder: " << e.what() << std::endl; + } + for (size_t o = 0; o < model->operators.size(); o++) { + Op *op = model->operators[o]; + if (op->op_type != OP_LORA) { + continue; + } + std::vector inputs(op->numInputs); + std::vector outputs(op->numOutputs); + for (int i = 0; i < op->numInputs; i++) { + assert(op->inputs[i] != nullptr); + assert(op->inputs[i]->parallel_is != IndexSpace::NO_SPACE); + assert(tensor_buffer[op->inputs[i]].size() > 0); + inputs[i] = tensor_buffer[op->inputs[i]][0]; + assert(inputs[i]->parallel_is != IndexSpace::NO_SPACE); + } + assert(op->numOutputs > 0); + for (int i = 0; i < op->numOutputs; i++) { + assert(op->outputs[i] != nullptr); + assert(op->outputs[i]->parallel_is != IndexSpace::NO_SPACE); + assert(tensor_buffer[op->outputs[i]].size() > 0); + outputs[i] = tensor_buffer[op->outputs[i]][0]; + assert(outputs[i]->parallel_is != IndexSpace::NO_SPACE); + } + LoraLinear *lora = static_cast(model->operators[o]); + lora->save_peft_weights( + *model, model_id, rank, destination_folder, inputs, outputs); + } +} + FutureMap InferenceManager::inference(FFModel *model, int index, BatchConfig const &bc) { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index bf84a621a6..8d96f2e68b 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -6760,6 +6760,22 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + { + TaskVariantRegistrar registrar(LORA_LINEAR_SAVE_WEIGHTS_TASK_ID, + "LoraLinear Save PEFT Weights"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "LoraLinear Save PEFT Weights Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + 
runtime->register_task_variant( + registrar); + } + } // NoOp { diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 4aeaf8aff0..2a92f0229c 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -2806,6 +2806,14 @@ bool is_peft_operator_type(OperatorType type) { } } +void RequestManager::save_peft_weights(FFModel *model, + PEFTModelID const &model_id, + std::string const &destination_folder) { + // Save the weights of the model + InferenceManager *im = InferenceManager::get_inference_manager(); + im->save_peft_weights(model, model_id, destination_folder); +} + /*static*/ void RequestManager::serve_incr_decoding(FFModel *llm) { diff --git a/tests/upload_test.sh b/tests/upload_test.sh new file mode 100644 index 0000000000..c6b4e3d0f6 --- /dev/null +++ b/tests/upload_test.sh @@ -0,0 +1,60 @@ +#! /usr/bin/env bash +set -x +set -e + +# Cd into directory holding this script +cd "${BASH_SOURCE[0]%/*}" + +# Token to access private huggingface models (e.g. LLAMA-2) +HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-none} +if [[ "$HUGGINGFACE_TOKEN" != "none" ]]; then + huggingface-cli login --token "$HUGGINGFACE_TOKEN" +fi + +# Create test prompt file +mkdir -p ../inference/prompt +echo '["San Francisco, officially the City and County of San Francisco, is a "]' > ../inference/prompt/test_upload.json + +# Create output folder +mkdir -p ../inference/output +mkdir -p ../inference/configs + +# Enable backtrace in case we run into a segfault or assertion failure +export LEGION_BACKTRACE=1 + +# Create config files +cat > ../inference/configs/llama_small.json < ../inference/configs/llama_small_upload.json <