Skip to content

Commit

Permalink
Fixing tests and request msg
Browse files Browse the repository at this point in the history
  • Loading branch information
Rudiio committed Dec 8, 2023
1 parent 74c8eba commit e6d87c8
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 37 deletions.
24 changes: 20 additions & 4 deletions src/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,19 @@
from functools import wraps
from http import HTTPStatus
from typing import List
import os
import sys

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from src.models.predict_model import SentiBites
from src.app.schemas import Review

# Get the parent directory
parent_dir = os.path.dirname(os.path.realpath(__file__))

# Add the parent directory to sys.path
sys.path.append(parent_dir)

model = None

Expand Down Expand Up @@ -50,6 +60,7 @@ def _load_model():

global model
model = SentiBites("models/SentiBites/")


@app.get("/", tags=["General"]) # path operation decorator
@construct_response
Expand All @@ -65,19 +76,24 @@ def _index(request: Request):

@app.post("/models/", tags=["Prediction"])
@construct_response
def _predict(request: Request, payload: str):
def _predict(request: Request, payload: Review):
"""Performs sentiment analysis based on the food review."""

if model:
prediction = model.predict(payload)
prediction,scores = model.predict(payload.msg)

response = {
"message": HTTPStatus.OK.phrase,
"status-code": HTTPStatus.OK,
"data": {
"model-type": "RoBERTaSB",
"payload": payload,
"payload": payload.msg,
"prediction": prediction,
"Scores" : {
"positive" : scores['positive'],
"neutral" : scores['neutral'],
"negative" : scores['negative']
}
},
}
else:
Expand Down
4 changes: 4 additions & 0 deletions src/app/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from pydantic import BaseModel

class Review(BaseModel):
msg : str
13 changes: 10 additions & 3 deletions src/models/predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,18 @@ def predict(self, text):
scores = output[0][0].detach().numpy()
scores = softmax(scores)

Printing the prediction
Selecting the best score
ranking = np.argsort(scores)
ranking = ranking[::-1]

return self.config.id2label[ranking[0]]

# Stroring the scores
res = {}
for i in range(scores.shape[0]):
length = self.config.id2label[ranking[i]]
score = scores[ranking[i]]
res[length] = float(score)

return self.config.id2label[ranking[0]],res

def preprocess(text):
"""remove links and mentions in a sentence"""
Expand Down
81 changes: 51 additions & 30 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,62 @@
from fastapi.testclient import TestClient
import os
import sys

# Get the parent directory
parent_dir = os.path.dirname(os.path.realpath(__file__))

# Add the parent directory to sys.path
sys.path.append(parent_dir)

from src.app.api import app
from http import HTTPStatus

client = TestClient(app)

def test_read_main():
response = client.get("/")
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["data"]["message"] == "Welcome to SentiBites! Please, read the `/docs`!"
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["data"]["message"] == "Welcome to SentiBites! Please, read the `/docs`!"

def test_read_prediction():
response = client.post("/models/", params={"payload": "This is a test."})
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["status-code"] == HTTPStatus.OK
assert response_body["data"]["model-type"] == "RoBERTaSB"
assert response_body["data"]["payload"] == "This is a test."
with TestClient(app) as client:
response = client.post("/models", json = {'msg':"This is a test."})
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["status-code"] == HTTPStatus.OK
assert response_body["data"]["model-type"] == "RoBERTaSB"
assert response_body["data"]["payload"] == "This is a test."

def test_positive_prediction():
response = client.post("/models/", params={"payload": "This food is really good."})
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["status-code"] == HTTPStatus.OK
assert response_body["data"]["model-type"] == "RoBERTaSB"
assert response_body["data"]["payload"] == "This food is really good."
assert response_body["data"]["prediction"] == "positive"
with TestClient(app) as client:
response = client.post("/models/", json={"msg": "This food is really good."})
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["status-code"] == HTTPStatus.OK
assert response_body["data"]["model-type"] == "RoBERTaSB"
assert response_body["data"]["payload"] == "This food is really good."
assert response_body["data"]["prediction"] == "positive"

def test_negative_prediction():
response = client.post("/models/", params={"payload": "Never buying this again."})
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["status-code"] == HTTPStatus.OK
assert response_body["data"]["model-type"] == "RoBERTaSB"
assert response_body["data"]["payload"] == "Never buying this again."
assert response_body["data"]["prediction"] == "negative"
with TestClient(app) as client:
response = client.post("/models/", json={"msg": "Never buying this again."})
assert response.status_code == 200
response_body = response.json()
assert response_body["message"] == HTTPStatus.OK.phrase
assert response_body["status-code"] == HTTPStatus.OK
assert response_body["data"]["model-type"] == "RoBERTaSB"
assert response_body["data"]["payload"] == "Never buying this again."
assert response_body["data"]["prediction"] == "negative"

def test_bad_url():
with TestClient(app) as client:
response = client.post("/mode", json={"msg": "Never buying this again."})
assert response.status_code == 404

def test_bad_request():
with TestClient(app) as client:
response = client.post("/models/", json={"false": "Never buying this again."})
assert response.status_code == 422

0 comments on commit e6d87c8

Please sign in to comment.