diff --git a/api/main.py b/api/main.py index 5943847..11465f8 100644 --- a/api/main.py +++ b/api/main.py @@ -15,6 +15,10 @@ _VERSION: str = "0.1.107" _S3_BUCKET_NAME_FOR_INPUTS: str = "voicecloning-inputs" +BATCH_CLIENT = boto3.client("batch") +BATCH_JOB_QUEUE = "batch-fargate-voicecloning-job-queue" +BATCH_JOB_DEFINITION = "aws-batch-job-definition-for-fargate-for-voicecloning" + app = FastAPI( title="VoiceCloning API", description="VoiceCloning API", @@ -51,7 +55,7 @@ class Project(ProjectBase): # This includes the id, created_at and audio_files, for responses id: str created_at: str - progress: str + # progress: str audio_files: List[str] @@ -83,7 +87,7 @@ async def get_projects() -> List[Project]: "text": item["text"]["S"], "quality": item["quality"]["S"], "created_at": item["created_at"]["S"], - "progress": item["progress"]["S"], + # "progress": item["progress"]["S"], # Convert string to list "audio_files": ast.literal_eval(item["audio_files"]["S"]), } @@ -109,7 +113,7 @@ async def get_project(project_id: str) -> Project: "text": response["text"]["S"], "quality": response["quality"]["S"], "created_at": response["created_at"]["S"], - "progress": response["progress"]["S"], + # "progress": response["progress"]["S"], "audio_files": ast.literal_eval(response["audio_files"]["S"]), } return Project(**project_data) @@ -140,18 +144,25 @@ async def create_project( "quality": quality, "id": unique_project_id, "created_at": time_now, - "progress": "ongoing", # options: 'ongoing', 'failed', 'ready' "audio_files": [audio_file.filename for audio_file in audio_files], + # NOT IMPLEMENTED YET + # "progress": "ongoing", # options: 'ongoing', 'failed', 'ready' } # Assuming you have a database client named db_client - response = db_client.put_item( + response_db = db_client.put_item( TableName=_TABLE_NAME, Item={**{key: {"S": str(value)} for key, value in project_data.items()}}, ) + try: + response_batch = _submit_batch_job(unique_project_id) + except Exception as e: + print(f"Failed to submit batch job: {e}") + return { - "status": response["ResponseMetadata"]["HTTPStatusCode"], + "status_for_db": response_db["ResponseMetadata"]["HTTPStatusCode"], + "status_for_batch": response_batch["ResponseMetadata"]["HTTPStatusCode"], "project": project_data, } @@ -194,3 +205,18 @@ def _process_and_upload_files(audio_files: List[UploadFile], project_id: str) -> for audio_file in audio_files: presigned_url = _generate_presigned_url(project_id, audio_file) _upload_to_s3(audio_file=audio_file, presigned_url=presigned_url) + + +def _submit_batch_job(project_id: str) -> str: + """Submit a batch job to AWS Batch.""" + response = BATCH_CLIENT.submit_job( + jobName=f"VoiceCloning-job-{str(uuid.uuid4().hex)}", + jobQueue=BATCH_JOB_QUEUE, + jobDefinition=BATCH_JOB_DEFINITION, + containerOverrides={ + "environment": [{"name": "PROJECT_ID", "value": project_id}] + }, + ) + + print(f"Job submitted. Job ID: {response['jobId']}") + return response["jobId"] \ No newline at end of file