diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..b7138ce --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +PEM_PATH= +SERVER_ADDRESS= +DESTINATION_FILE= \ No newline at end of file diff --git a/api.py b/api.py index d108444..00873b3 100644 --- a/api.py +++ b/api.py @@ -21,6 +21,11 @@ from fastapi import File, UploadFile, Form from typing import Annotated import shutil +import subprocess +import os +from dotenv import load_dotenv + +load_dotenv() class Api: @@ -313,6 +318,7 @@ def upload_lora_and_merge_lora_to_checkpoint(self, lora_file: UploadFile, merge_ print("Started to merge lora") merged_res = self.merge_lora(merge_request) + self.copy_checkpoint(merged_res) message = f'Upload and merge lora <{lora_file.filename}> to checkpoint <{merge_request.model}> successfully.' @@ -323,6 +329,19 @@ def upload_lora_and_merge_lora_to_checkpoint(self, lora_file: UploadFile, merge_ raise e # end try + def copy_checkpoint(source_file): + pem_file = os.environ['PEM_PATH'] + server_address = os.environ['SERVER_ADDRESS'] + destination_file = os.environ['DESTINATION_FILE'] + + command = ["sudo", "scp", "-i", pem_file, source_file, server_address + ":" + destination_file] + + try: + subprocess.run(command, check=True) + print("File " + source_file + " copied successfully!") + except subprocess.CalledProcessError as e: + print("Error copying file:", e.output) + raise e def on_app_started(_, app: FastAPI): Api(app, queue_lock, '/supermerger/v1')