From 48309f5ec54bc779f0d791507d0a85f1aae6b28a Mon Sep 17 00:00:00 2001
From: dberenbaum
Date: Wed, 12 Jul 2023 18:30:14 -0400
Subject: [PATCH] lightning: add tests for log_model

---
 src/dvclive/lightning.py                | 14 ++++-----
 tests/test_frameworks/test_lightning.py | 39 +++++++++++++++++++++++--
 2 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py
index 39361fa7..fef5de26 100644
--- a/src/dvclive/lightning.py
+++ b/src/dvclive/lightning.py
@@ -130,7 +130,9 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
         self.experiment.next_step()
 
     def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
-        self._checkpoint_callback = checkpoint_callback
+        if self._log_model in [True, "all"]:
+            self._checkpoint_callback = checkpoint_callback
+            self._scan_checkpoints(checkpoint_callback)
         if self._log_model == "all" or (
             self._log_model is True and checkpoint_callback.save_top_k == -1
         ):
@@ -138,13 +140,11 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
-        checkpoint_callback = self._checkpoint_callback
-        # Save model checkpoints.
-        if self._log_model is True:
-            self._save_checkpoints(checkpoint_callback)
         # Log best model.
-        if self._log_model in (True, "all"):
-            best_model_path = checkpoint_callback.best_model_path
+        if self._checkpoint_callback:
+            self._scan_checkpoints(self._checkpoint_callback)
+            self._save_checkpoints(self._checkpoint_callback)
+            best_model_path = self._checkpoint_callback.best_model_path
             self.experiment.log_artifact(
                 best_model_path, name="best", type="model", cache=False
             )
diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py
index 8aa4fbf5..e572b235 100644
--- a/tests/test_frameworks/test_lightning.py
+++ b/tests/test_frameworks/test_lightning.py
@@ -8,8 +8,9 @@
 
 try:
     import torch
-    from pytorch_lightning import LightningModule
-    from pytorch_lightning.trainer import Trainer
+    from lightning import LightningModule
+    from lightning.pytorch import Trainer
+    from lightning.pytorch.callbacks import ModelCheckpoint
     from torch import nn
     from torch.nn import functional as F  # noqa: N812
     from torch.optim import SGD, Adam
@@ -18,7 +19,7 @@
     from dvclive import Live
     from dvclive.lightning import DVCLiveLogger
 except ImportError:
-    pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
+    pytest.skip("skipping lightning tests", allow_module_level=True)
 
 
 class XORDataset(Dataset):
@@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir):
     assert dvclive_logger.experiment._cache_images is True
 
 
+@pytest.mark.parametrize("log_model", [False, True, "all"])
+@pytest.mark.parametrize("save_top_k", [1, -1])
+def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k):
+    model = LitXOR()
+    dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model)
+    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k)
+    trainer = Trainer(
+        logger=dvclive_logger,
+        max_epochs=2,
+        log_every_n_steps=1,
+        callbacks=[checkpoint],
+    )
+    log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact")
+    trainer.fit(model)
+
+    # Check that log_artifact is called.
+    if log_model is False:
+        log_artifact.assert_not_called()
+    elif (log_model is True) and (save_top_k != -1):
+        # Called once to cache the checkpoint, then again to log the best model.
+        assert log_artifact.call_count == 2
+    else:
+        # Once per epoch, plus the two calls at the end (see above).
+        assert log_artifact.call_count == 4
+
+    # Check that the number of checkpoint files does not grow with each run.
+    num_checkpoints = len(os.listdir(tmp_dir / "model"))
+    if log_model in [True, "all"]:
+        trainer.fit(model)
+        assert len(os.listdir(tmp_dir / "model")) == num_checkpoints
+
+
 def test_lightning_steps(tmp_dir, mocker):
     model = LitXOR()
     # Handle kwargs passed to Live.
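
For context, a minimal sketch of the log_model behavior the new test exercises, written against the patched logger above. It reuses names from the test (LitXOR, the "model" checkpoint directory); the trainer arguments are illustrative only:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    from dvclive.lightning import DVCLiveLogger

    # log_model=False (the default): no checkpoints are logged.
    # log_model=True: checkpoints are logged once, when training finalizes.
    # log_model="all" (or True with save_top_k=-1): checkpoints are also
    # logged after every save.
    logger = DVCLiveLogger(log_model=True)
    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=1)
    trainer = Trainer(logger=logger, max_epochs=2, callbacks=[checkpoint])
    trainer.fit(LitXOR())
    # On finalize(), the logger additionally logs the checkpoint at
    # checkpoint.best_model_path as an artifact named "best".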