diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index d728f6bd..8482c6e8 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -10,7 +10,6 @@ from mlflow.store.entities import PagedList from sklearn.preprocessing import StandardScaler -from mlflow.models.model import ModelInfo from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache @@ -29,7 +28,7 @@ mock_list_of_model_version, mock_list_of_model_version2, return_sklearn_rundata, - mock_get_model_version_obj + mock_get_model_version_obj, ) TRACKING_URI = "http://0.0.0.0:5009" @@ -56,7 +55,6 @@ def test_construct_key(self): key = MLflowRegistry.construct_key(skeys, dkeys) self.assertEqual("model_:nnet::error1", key) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @@ -74,7 +72,6 @@ def test_save_model(self): mock_status = "READY" self.assertEqual(mock_status, status.status) - @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata()))) @@ -87,19 +84,15 @@ def test_save_model_sklearn(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - status = ml.save(skeys=skeys, - dkeys=dkeys, - artifact=model, - artifact_type="sklearn") - + status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn") + mock_status = "READY" self.assertEqual(mock_status, status.status) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)]))) + @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)]))) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @@ -110,16 +103,10 @@ def test_load_model_when_pytorch_model_exist1(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, - dkeys=dkeys, - artifact=model, - **{"lr": 0.01}, - artifact_type="pytorch") + ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch") data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertIsNotNone(data.metadata) self.assertIsInstance(data.artifact, VanillaAE) - - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @@ -134,10 +121,7 @@ def test_load_model_when_pytorch_model_exist2(self): ml = MLflowRegistry(TRACKING_URI, models_to_retain=2) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, - dkeys=dkeys, - artifact=model, - artifact_type="pytorch") + ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch") data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertEqual(data.metadata, {}) self.assertIsInstance(data.artifact, VanillaAE) @@ -169,8 +153,6 @@ def test_load_model_when_sklearn_model_exist(self): self.assertIsInstance(data.artifact, StandardScaler) self.assertEqual(data.metadata, {}) - - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata()))) @patch("mlflow.active_run", Mock(return_value=return_empty_rundata())) @@ -190,8 +172,6 @@ def test_load_model_with_version(self): self.assertIsInstance(data.artifact, VanillaAE) self.assertEqual(data.metadata, {}) - - @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch( @@ -207,10 +187,9 @@ def test_staging_model_load_error(self): with self.assertLogs(level="ERROR") as log: result = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertIsNone(result) # Ensure the result is None - self.assertTrue(any("No Model found" in message for message in log.output)) # Check that the expected log was made - - - + self.assertTrue( + any("No Model found" in message for message in log.output) + ) # Check that the expected log was made @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @@ -223,7 +202,6 @@ def test_both_version_latest_model_with_version(self): dkeys = self.dkeys with self.assertRaises(ValueError): ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch") - @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @@ -280,8 +258,6 @@ def test_no_implementation(self): ml.load(skeys=fake_skeys, dkeys=fake_dkeys) self.assertTrue(log.output) - - @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @@ -356,12 +332,10 @@ def test_load_other_mlflow_err(self): dkeys = self.dkeys self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")) - - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)]))) + @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)]))) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @@ -380,11 +354,10 @@ def test_is_model_stale_true(self): data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") self.assertTrue(ml.is_artifact_stale(data, 12)) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)]))) + @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate", 0.01)]))) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @@ -421,8 +394,6 @@ def test_cache(self): self.assertIsNotNone(registry._load_from_cache("key")) self.assertIsNotNone(registry._clear_cache("key")) - - @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @@ -437,6 +408,5 @@ def test_cache_loading(self): self.assertIsNotNone(ml._load_from_cache(key)) - if __name__ == "__main__": unittest.main()