Skip to content

Commit

Permalink
refactor: Move num_offers_to_return inside predict
Browse files Browse the repository at this point in the history
  • Loading branch information
lmontier-pass committed Sep 23, 2024
1 parent ba05816 commit 62ae50d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@

class OfferCategorisationModel:
MODEL_NAME = ModelName.OFFER_CATEGORISATION
NUM_OFFERS_TO_RETURN = 3

def __init__(self):
self.model_handler = ModelHandler()
model_data = self._load_models()
self.model = model_data.model

def predict(self, data: OfferCategorisationInput) -> OfferCategorisationOutput:
def predict(
self, data: OfferCategorisationInput, num_offers_to_return: int
) -> OfferCategorisationOutput:
"""
Predicts the class labels for the given data using the trained classifier model.
Expand All @@ -31,6 +32,8 @@ def predict(self, data: OfferCategorisationInput) -> OfferCategorisationOutput:
and the main contributions.
"""
predictions = self.model.predict(data.dict())

num_offers_to_return = min(num_offers_to_return, len(predictions.subcategory))
predictions_df = (
pd.DataFrame(
{
Expand All @@ -39,7 +42,7 @@ def predict(self, data: OfferCategorisationInput) -> OfferCategorisationOutput:
}
)
.sort_values("probability", ascending=False)
.iloc[: self.NUM_OFFERS_TO_RETURN]
.iloc[:num_offers_to_return]
)

return OfferCategorisationOutput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def model_categorisation(input: OfferCategorisationInput):
}

formatted_predictions = offer_categorisation_model.predict(
data=input,
data=input, num_offers_to_return=NUM_OFFERS_TO_RETURN
)

custom_logger.info(formatted_predictions.dict(), extra=log_extra_data)
Expand Down

0 comments on commit 62ae50d

Please sign in to comment.