Skip to content

Commit

Permalink
refactor: Return proper output for offer_categorisation_model
Browse files Browse the repository at this point in the history
  • Loading branch information
lmontier-pass committed Sep 17, 2024
1 parent 59d7e19 commit 6667c4a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pandas as pd
from main import custom_logger
from pcpapillon.utils.constants import (
ModelName,
Expand All @@ -11,6 +12,7 @@

class OfferCategorisationModel:
MODEL_NAME = ModelName.OFFER_CATEGORISATION
NUM_OFFERS_TO_RETURN = 3

def __init__(self):
self.model_handler = ModelHandler()
Expand All @@ -29,9 +31,20 @@ def predict(self, data: OfferCategorisationInput) -> OfferCategorisationOutput:
and the main contributions.
"""
predictions = self.model.predict(data.dict())
print(predictions)
predictions_df = (
pd.DataFrame(
{
"subcategory": predictions.subcategory,
"probability": predictions.probability,
}
)
.sort_values("probability", ascending=False)
.iloc[: self.NUM_OFFERS_TO_RETURN]
)

return OfferCategorisationOutput(most_probable_subcategories=predictions)
return OfferCategorisationOutput(
most_probable_subcategories=predictions_df.to_dict(orient="records")
)

def _load_models(self) -> ModelWithMetadata:
custom_logger.info("Load offer categorisation model..")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def model_categorisation(input: OfferCategorisationInput):
}

formatted_predictions = offer_categorisation_model.predict(
input=input,
data=input,
)

custom_logger.info(formatted_predictions.dict(), extra=log_extra_data)
Expand Down

0 comments on commit 6667c4a

Please sign in to comment.