Skip to content

Commit

Permalink
fix: imports
Browse files Browse the repository at this point in the history
Signed-off-by: Leila Wang <leilawang@cs.toronto.edu>
  • Loading branch information
yleilawang committed Sep 26, 2024
1 parent 74c10bb commit 36fd6a4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
11 changes: 5 additions & 6 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
from enum import Enum
from typing import Optional, Any

import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
import mlflow
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.exceptions import RestException, MlflowException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient

Expand Down Expand Up @@ -210,7 +207,7 @@ def load_multiple(

try:
unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
except mlflow.exceptions.MlflowException as e:
except MlflowException as e:
raise TypeError("The loaded model is not a valid pyfunc Python model.") from e
except AttributeError:
_LOGGER.exception("The loaded model does not have an unwrap_python_model method")
Expand Down Expand Up @@ -303,7 +300,9 @@ def save_multiple(
"""
if len(dict_artifacts) == 1:
_LOGGER.warning("Only one element in dict_artifacts. Saving directly is recommended.")
_LOGGER.warning(
"Only one artifact present in dict_artifacts. Saving directly is recommended."
)
multiple_artifacts = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
return self.save(
skeys=skeys,
Expand Down
3 changes: 3 additions & 0 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from contextlib import contextmanager
from unittest.mock import patch, Mock

import mlflow.pytorch # noqa: F401
import mlflow.pyfunc # noqa: F401
import mlflow.sklearn # noqa: F401
from freezegun import freeze_time
from mlflow import ActiveRun
from mlflow.exceptions import RestException
Expand Down

0 comments on commit 36fd6a4

Please sign in to comment.