diff --git a/whisper-image/app/predictor.py b/whisper-image/app/predictor.py index 43475b6..17a3736 100644 --- a/whisper-image/app/predictor.py +++ b/whisper-image/app/predictor.py @@ -17,7 +17,8 @@ app = flask.Flask(__name__) s3_client = boto3.client("s3") -model_name = "medium.en" +default_model_name = "turbo" + @app.route("/ping", methods=["GET"]) def ping(): @@ -32,10 +33,10 @@ def execution_parameters(): status = 200 return flask.Response(response="{}", status=status, mimetype="application/json") + @app.route("/invocations", methods=["POST"]) def transformation(): - """Do an inference on a single batch of data. - """ + """Do an inference on a single batch of data.""" content_type = flask.request.content_type request_data = flask.request.data logger.info(f"transformation: {content_type} {request_data}") @@ -47,11 +48,14 @@ def transformation(): input_dict = json.loads(data) else: return flask.Response( - response="The predictor only supports application/json content type", status=415, mimetype="text/plain" + response="The predictor only supports application/json content type", + status=415, + mimetype="text/plain", ) bucket_name = input_dict["bucket_name"] object_key = input_dict["object_key"] + model_name = input_dict.get("model_name", default_model_name) fd, filename = tempfile.mkstemp() try: os.close(fd) @@ -66,9 +70,6 @@ def transformation(): finally: os.unlink(filename) - payload = { - **input_dict, - "result": result - } + payload = {**input_dict, "result": result} response = json.dumps(payload) return flask.Response(response=response, status=200, mimetype="application/json") diff --git a/whisper-image/requirements.txt b/whisper-image/requirements.txt index 56997e8..dddabdf 100644 --- a/whisper-image/requirements.txt +++ b/whisper-image/requirements.txt @@ -1,4 +1,4 @@ boto3==1.24.96 Flask==3.0.1 -gunicorn==21.2.0 -openai-whisper==20230918 +gunicorn==21.2.0 +openai-whisper==20240930