Skip to content

Commit

Permalink
refactor: Readd get_model_hash_from_mlflow
Browse files Browse the repository at this point in the history
  • Loading branch information
lmontier-pass committed Aug 28, 2024
1 parent 4fa0386 commit 5005d5b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ def predict(self, data: ComplianceInput) -> ComplianceOutput:
data (ComplianceInput): Input data to be predicted.
Returns:
tuple: A tuple containing the predicted class labels and the main contribution.
offer validition probability
offer rejection probability (=1-proba_val)
main features contributing to increase validation probability
main features contributing to reduce validation probability
ComplianceOutput: An object containing the predicted class labels
and the main contributions.
"""
predictions = self.model.predict(data.dict())
return ComplianceOutput(
Expand Down
16 changes: 8 additions & 8 deletions apps/fraud/compliance/api/src/pcpapillon/utils/model_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import pickle
import re
from dataclasses import dataclass
from typing import Union

Expand Down Expand Up @@ -30,8 +31,6 @@ def __init__(self) -> None:
def get_model_with_metadata_by_name(
self, model_name: str, model_type=ModelType.DEFAULT
) -> ModelWithMetadata:
print("model_name", model_name)
print("model_type", model_type)
if model_name == ModelName.COMPLIANCE:
loaded_model = mlflow.pyfunc.load_model(
model_uri=f"models:/{self._get_mlflow_model_name(model_name)}"
Expand Down Expand Up @@ -62,14 +61,15 @@ def _get_hash(obj):
return hashlib.md5(obj_bytes).hexdigest()

def get_model_hash_from_mlflow(self, model_name: str):
# mlflow_model_name = self._get_mlflow_model_name(model_name=model_name)
# if not self.mlflow_client:
# raise ValueError("No mlflow client connected")
SPLIT_PATTERN = "/|@"

# model_version = self.mlflow_client.get_latest_versions(mlflow_model_name)
mlflow_model_name = self._get_mlflow_model_name(model_name=model_name)
mlflow_model_name_stripped = re.split(SPLIT_PATTERN, mlflow_model_name)[0]

# return self._get_hash(model_version)
return f"{model_name}-12345"
model_version = self.mlflow_client.get_latest_versions(
mlflow_model_name_stripped
)
return self._get_hash(model_version)

@staticmethod
def _get_mlflow_model_name(model_name: ModelName):
Expand Down

0 comments on commit 5005d5b

Please sign in to comment.