diff --git a/src/app/api.py b/src/app/api.py index 0cc3e3a..706b43c 100644 --- a/src/app/api.py +++ b/src/app/api.py @@ -5,9 +5,19 @@ from functools import wraps from http import HTTPStatus from typing import List +import os +import sys +from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, Request from src.models.predict_model import SentiBites +from src.app.schemas import Review + +# Get the parent directory +parent_dir = os.path.dirname(os.path.realpath(__file__)) + +# Add the parent directory to sys.path +sys.path.append(parent_dir) model = None @@ -50,6 +60,7 @@ def _load_model(): global model model = SentiBites("models/SentiBites/") + @app.get("/", tags=["General"]) # path operation decorator @construct_response @@ -65,19 +76,24 @@ def _index(request: Request): @app.post("/models/", tags=["Prediction"]) @construct_response -def _predict(request: Request, payload: str): +def _predict(request: Request, payload: Review): """Performs sentiment analysis based on the food review.""" - + if model: - prediction = model.predict(payload) + prediction,scores = model.predict(payload.msg) response = { "message": HTTPStatus.OK.phrase, "status-code": HTTPStatus.OK, "data": { "model-type": "RoBERTaSB", - "payload": payload, + "payload": payload.msg, "prediction": prediction, + "Scores" : { + "positive" : scores['positive'], + "neutral" : scores['neutral'], + "negative" : scores['negative'] + } }, } else: diff --git a/src/app/schemas.py b/src/app/schemas.py new file mode 100644 index 0000000..19693f8 --- /dev/null +++ b/src/app/schemas.py @@ -0,0 +1,4 @@ +from pydantic import BaseModel + +class Review(BaseModel): + msg : str \ No newline at end of file diff --git a/src/models/predict_model.py b/src/models/predict_model.py index 91c70e5..8a91f64 100755 --- a/src/models/predict_model.py +++ b/src/models/predict_model.py @@ -28,11 +28,18 @@ def predict(self, text): scores = output[0][0].detach().numpy() scores = softmax(scores) - # Printing the prediction + # Selecting the best score ranking = np.argsort(scores) ranking = ranking[::-1] - - return self.config.id2label[ranking[0]] + + # Stroring the scores + res = {} + for i in range(scores.shape[0]): + length = self.config.id2label[ranking[i]] + score = scores[ranking[i]] + res[length] = float(score) + + return self.config.id2label[ranking[0]],res def preprocess(text): """remove links and mentions in a sentence""" diff --git a/tests/test_api.py b/tests/test_api.py index 1302db1..93eb43d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,41 +1,62 @@ from fastapi.testclient import TestClient +import os +import sys + +# Get the parent directory +parent_dir = os.path.dirname(os.path.realpath(__file__)) + +# Add the parent directory to sys.path +sys.path.append(parent_dir) + from src.app.api import app from http import HTTPStatus -client = TestClient(app) - def test_read_main(): - response = client.get("/") - assert response.status_code == 200 - response_body = response.json() - assert response_body["message"] == HTTPStatus.OK.phrase - assert response_body["data"]["message"] == "Welcome to SentiBites! Please, read the `/docs`!" + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200 + response_body = response.json() + assert response_body["message"] == HTTPStatus.OK.phrase + assert response_body["data"]["message"] == "Welcome to SentiBites! Please, read the `/docs`!" def test_read_prediction(): - response = client.post("/models/", params={"payload": "This is a test."}) - assert response.status_code == 200 - response_body = response.json() - assert response_body["message"] == HTTPStatus.OK.phrase - assert response_body["status-code"] == HTTPStatus.OK - assert response_body["data"]["model-type"] == "RoBERTaSB" - assert response_body["data"]["payload"] == "This is a test." + with TestClient(app) as client: + response = client.post("/models", json = {'msg':"This is a test."}) + assert response.status_code == 200 + response_body = response.json() + assert response_body["message"] == HTTPStatus.OK.phrase + assert response_body["status-code"] == HTTPStatus.OK + assert response_body["data"]["model-type"] == "RoBERTaSB" + assert response_body["data"]["payload"] == "This is a test." def test_positive_prediction(): - response = client.post("/models/", params={"payload": "This food is really good."}) - assert response.status_code == 200 - response_body = response.json() - assert response_body["message"] == HTTPStatus.OK.phrase - assert response_body["status-code"] == HTTPStatus.OK - assert response_body["data"]["model-type"] == "RoBERTaSB" - assert response_body["data"]["payload"] == "This food is really good." - assert response_body["data"]["prediction"] == "positive" + with TestClient(app) as client: + response = client.post("/models/", json={"msg": "This food is really good."}) + assert response.status_code == 200 + response_body = response.json() + assert response_body["message"] == HTTPStatus.OK.phrase + assert response_body["status-code"] == HTTPStatus.OK + assert response_body["data"]["model-type"] == "RoBERTaSB" + assert response_body["data"]["payload"] == "This food is really good." + assert response_body["data"]["prediction"] == "positive" def test_negative_prediction(): - response = client.post("/models/", params={"payload": "Never buying this again."}) - assert response.status_code == 200 - response_body = response.json() - assert response_body["message"] == HTTPStatus.OK.phrase - assert response_body["status-code"] == HTTPStatus.OK - assert response_body["data"]["model-type"] == "RoBERTaSB" - assert response_body["data"]["payload"] == "Never buying this again." - assert response_body["data"]["prediction"] == "negative" \ No newline at end of file + with TestClient(app) as client: + response = client.post("/models/", json={"msg": "Never buying this again."}) + assert response.status_code == 200 + response_body = response.json() + assert response_body["message"] == HTTPStatus.OK.phrase + assert response_body["status-code"] == HTTPStatus.OK + assert response_body["data"]["model-type"] == "RoBERTaSB" + assert response_body["data"]["payload"] == "Never buying this again." + assert response_body["data"]["prediction"] == "negative" + +def test_bad_url(): + with TestClient(app) as client: + response = client.post("/mode", json={"msg": "Never buying this again."}) + assert response.status_code == 404 + +def test_bad_request(): + with TestClient(app) as client: + response = client.post("/models/", json={"false": "Never buying this again."}) + assert response.status_code == 422 \ No newline at end of file