Skip to content

Commit

Permalink
fix: lint
Browse files Browse the repository at this point in the history
Signed-off-by: Leila <leilawang@cs.toronto.edu>
  • Loading branch information
yleilawang committed Sep 18, 2024
1 parent 5a43c33 commit 1d5cbbc
Showing 1 changed file with 11 additions and 41 deletions.
52 changes: 11 additions & 41 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from mlflow.store.entities import PagedList
from sklearn.preprocessing import StandardScaler

from mlflow.models.model import ModelInfo
from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache

Expand All @@ -29,7 +28,7 @@
mock_list_of_model_version,
mock_list_of_model_version2,
return_sklearn_rundata,
mock_get_model_version_obj
mock_get_model_version_obj,
)

TRACKING_URI = "http://0.0.0.0:5009"
Expand All @@ -56,7 +55,6 @@ def test_construct_key(self):
key = MLflowRegistry.construct_key(skeys, dkeys)
self.assertEqual("model_:nnet::error1", key)


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
Expand All @@ -74,7 +72,6 @@ def test_save_model(self):
mock_status = "READY"
self.assertEqual(mock_status, status.status)


@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
Expand All @@ -87,19 +84,15 @@ def test_save_model_sklearn(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys,
dkeys=dkeys,
artifact=model,
artifact_type="sklearn")

status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn")

mock_status = "READY"
self.assertEqual(mock_status, status.status)


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)])))
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
Expand All @@ -110,16 +103,10 @@ def test_load_model_when_pytorch_model_exist1(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys,
dkeys=dkeys,
artifact=model,
**{"lr": 0.01},
artifact_type="pytorch")
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNotNone(data.metadata)
self.assertIsInstance(data.artifact, VanillaAE)



@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
Expand All @@ -134,10 +121,7 @@ def test_load_model_when_pytorch_model_exist2(self):
ml = MLflowRegistry(TRACKING_URI, models_to_retain=2)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys,
dkeys=dkeys,
artifact=model,
artifact_type="pytorch")
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertEqual(data.metadata, {})
self.assertIsInstance(data.artifact, VanillaAE)
Expand Down Expand Up @@ -169,8 +153,6 @@ def test_load_model_when_sklearn_model_exist(self):
self.assertIsInstance(data.artifact, StandardScaler)
self.assertEqual(data.metadata, {})



@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_empty_rundata()))
Expand All @@ -190,8 +172,6 @@ def test_load_model_with_version(self):
self.assertIsInstance(data.artifact, VanillaAE)
self.assertEqual(data.metadata, {})



@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch(
Expand All @@ -207,10 +187,9 @@ def test_staging_model_load_error(self):
with self.assertLogs(level="ERROR") as log:
result = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNone(result) # Ensure the result is None
self.assertTrue(any("No Model found" in message for message in log.output)) # Check that the expected log was made



self.assertTrue(
any("No Model found" in message for message in log.output)
) # Check that the expected log was made

@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
Expand All @@ -223,7 +202,6 @@ def test_both_version_latest_model_with_version(self):
dkeys = self.dkeys
with self.assertRaises(ValueError):
ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch")


@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
Expand Down Expand Up @@ -280,8 +258,6 @@ def test_no_implementation(self):
ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
self.assertTrue(log.output)



@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
Expand Down Expand Up @@ -356,12 +332,10 @@ def test_load_other_mlflow_err(self):
dkeys = self.dkeys
self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch"))



@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)])))
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
Expand All @@ -380,11 +354,10 @@ def test_is_model_stale_true(self):
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertTrue(ml.is_artifact_stale(data, 12))


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)])))
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
Expand Down Expand Up @@ -421,8 +394,6 @@ def test_cache(self):
self.assertIsNotNone(registry._load_from_cache("key"))
self.assertIsNotNone(registry._clear_cache("key"))



@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
Expand All @@ -437,6 +408,5 @@ def test_cache_loading(self):
self.assertIsNotNone(ml._load_from_cache(key))



if __name__ == "__main__":
unittest.main()

0 comments on commit 1d5cbbc

Please sign in to comment.