Skip to content

Commit

Permalink
refactor: Refacto compliance_model
Browse files Browse the repository at this point in the history
  • Loading branch information
lmontier-pass committed Aug 28, 2024
1 parent ebf5701 commit 4fa0386
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
16 changes: 12 additions & 4 deletions apps/fraud/compliance/api/src/pcpapillon/core/compliance_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import mlflow
from main import custom_logger
from pcpapillon.utils.constants import ModelName, ModelType
from pcpapillon.utils.data_model import ComplianceInput, ComplianceOutput
from pcpapillon.utils.model_handler import ModelHandler, ModelWithMetadata
from sentence_transformers import SentenceTransformer

Expand Down Expand Up @@ -34,22 +35,29 @@ def _load_models(
model_name=self.MODEL_NAME, model_type=self.MODEL_TYPE
)

def predict(self, data):
def predict(self, data: ComplianceInput) -> ComplianceOutput:
"""
Predicts the class labels for the given data using the trained classifier model.
Args:
api_config (dict): Configuration parameters for the API.
data (list): Input data to be predicted.
data (ComplianceInput): Input data to be predicted.
Returns:
tuple: A tuple containing the predicted class labels and the main contribution.
offer validition probability
offer rejection probability (=1-proba_val)
main features contributing to increase validation probability
main features contributing to reduce validation probability
"""
return self.model.predict(data)
predictions = self.model.predict(data.dict())
return ComplianceOutput(
offer_id=data.offer_id,
probability_validated=predictions.probability_validated,
validation_main_features=predictions.validation_main_features,
probability_rejected=predictions.probability_rejected,
rejection_main_features=predictions.rejection_main_features,
)

def _is_newer_model_available(self) -> bool:
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class ModelType(Enum):
Enum class for model types
"""

LOCAL = "local"
DEFAULT = "default"
PREPROCESSING = "custom_sentence_transformer"

Expand Down
13 changes: 3 additions & 10 deletions apps/fraud/compliance/api/src/pcpapillon/views/compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,7 @@ def model_compliance_scoring(scoring_input: ComplianceInput):
"scoring_input": scoring_input.dict(),
}

predictions = compliance_model.predict(data=scoring_input.dict())
predictions = compliance_model.predict(data=scoring_input)

validation_response_dict = {
"offer_id": scoring_input.dict()["offer_id"],
"probability_validated": predictions.probability_validated,
"validation_main_features": predictions.validation_main_features,
"probability_rejected": predictions.probability_rejected,
"rejection_main_features": predictions.rejection_main_features,
}
custom_logger.info(validation_response_dict, extra=log_extra_data)
return validation_response_dict
custom_logger.info(predictions.dict(), extra=log_extra_data)
return predictions

0 comments on commit 4fa0386

Please sign in to comment.