From 4fa0386325fa5619bf3ec87eef54e57e7f1ebb6f Mon Sep 17 00:00:00 2001 From: LaurentM Pass Date: Wed, 28 Aug 2024 11:32:00 +0200 Subject: [PATCH] refactor: Refacto compliance_model --- .../api/src/pcpapillon/core/compliance_model.py | 16 ++++++++++++---- .../api/src/pcpapillon/utils/constants.py | 1 - .../api/src/pcpapillon/views/compliance.py | 13 +++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/apps/fraud/compliance/api/src/pcpapillon/core/compliance_model.py b/apps/fraud/compliance/api/src/pcpapillon/core/compliance_model.py index b939c285..83b0b3d7 100644 --- a/apps/fraud/compliance/api/src/pcpapillon/core/compliance_model.py +++ b/apps/fraud/compliance/api/src/pcpapillon/core/compliance_model.py @@ -4,6 +4,7 @@ import mlflow from main import custom_logger from pcpapillon.utils.constants import ModelName, ModelType +from pcpapillon.utils.data_model import ComplianceInput, ComplianceOutput from pcpapillon.utils.model_handler import ModelHandler, ModelWithMetadata from sentence_transformers import SentenceTransformer @@ -34,22 +35,29 @@ def _load_models( model_name=self.MODEL_NAME, model_type=self.MODEL_TYPE ) - def predict(self, data): + def predict(self, data: ComplianceInput) -> ComplianceOutput: """ Predicts the class labels for the given data using the trained classifier model. Args: - api_config (dict): Configuration parameters for the API. - data (list): Input data to be predicted. + data (ComplianceInput): Input data to be predicted. Returns: + tuple: A tuple containing the predicted class labels and the main contribution. offer validition probability offer rejection probability (=1-proba_val) main features contributing to increase validation probability main features contributing to reduce validation probability """ - return self.model.predict(data) + predictions = self.model.predict(data.dict()) + return ComplianceOutput( + offer_id=data.offer_id, + probability_validated=predictions.probability_validated, + validation_main_features=predictions.validation_main_features, + probability_rejected=predictions.probability_rejected, + rejection_main_features=predictions.rejection_main_features, + ) def _is_newer_model_available(self) -> bool: return ( diff --git a/apps/fraud/compliance/api/src/pcpapillon/utils/constants.py b/apps/fraud/compliance/api/src/pcpapillon/utils/constants.py index 8346b03d..de2e82fd 100644 --- a/apps/fraud/compliance/api/src/pcpapillon/utils/constants.py +++ b/apps/fraud/compliance/api/src/pcpapillon/utils/constants.py @@ -15,7 +15,6 @@ class ModelType(Enum): Enum class for model types """ - LOCAL = "local" DEFAULT = "default" PREPROCESSING = "custom_sentence_transformer" diff --git a/apps/fraud/compliance/api/src/pcpapillon/views/compliance.py b/apps/fraud/compliance/api/src/pcpapillon/views/compliance.py index 61a7bd90..e9018b1c 100644 --- a/apps/fraud/compliance/api/src/pcpapillon/views/compliance.py +++ b/apps/fraud/compliance/api/src/pcpapillon/views/compliance.py @@ -33,14 +33,7 @@ def model_compliance_scoring(scoring_input: ComplianceInput): "scoring_input": scoring_input.dict(), } - predictions = compliance_model.predict(data=scoring_input.dict()) + predictions = compliance_model.predict(data=scoring_input) - validation_response_dict = { - "offer_id": scoring_input.dict()["offer_id"], - "probability_validated": predictions.probability_validated, - "validation_main_features": predictions.validation_main_features, - "probability_rejected": predictions.probability_rejected, - "rejection_main_features": predictions.rejection_main_features, - } - custom_logger.info(validation_response_dict, extra=log_extra_data) - return validation_response_dict + custom_logger.info(predictions.dict(), extra=log_extra_data) + return predictions