lightning: add tests for log_model
dberenbaum committed Jul 12, 2023
1 parent 110b9aa commit 48309f5
Showing 2 changed files with 43 additions and 10 deletions.
14 changes: 7 additions & 7 deletions src/dvclive/lightning.py

@@ -130,21 +130,21 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
         self.experiment.next_step()
 
     def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
-        self._checkpoint_callback = checkpoint_callback
+        if self._log_model in [True, "all"]:
+            self._checkpoint_callback = checkpoint_callback
+            self._scan_checkpoints(checkpoint_callback)
         if self._log_model == "all" or (
             self._log_model is True and checkpoint_callback.save_top_k == -1
         ):
             self._save_checkpoints(checkpoint_callback)
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
-        checkpoint_callback = self._checkpoint_callback
-        # Save model checkpoints.
-        if self._log_model is True:
-            self._save_checkpoints(checkpoint_callback)
         # Log best model.
-        if self._log_model in (True, "all"):
-            best_model_path = checkpoint_callback.best_model_path
+        if self._checkpoint_callback:
+            self._scan_checkpoints(self._checkpoint_callback)
+            self._save_checkpoints(self._checkpoint_callback)
+            best_model_path = self._checkpoint_callback.best_model_path
             self.experiment.log_artifact(
                 best_model_path, name="best", type="model", cache=False
             )
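
Taken together, the patch moves all checkpoint bookkeeping behind the log_model switch: nothing is tracked unless log_model is True or "all", and finalize() only logs artifacts when a checkpoint callback was actually registered. A minimal usage sketch of the resulting behavior, assuming only the public API the new test below exercises (model stands in for any LightningModule, such as the tests' LitXOR):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    from dvclive.lightning import DVCLiveLogger

    # log_model=False: checkpoints are never logged as DVC artifacts.
    # log_model=True:  checkpoints are logged once in finalize(), plus a "best"
    #                  artifact (or after every save when save_top_k == -1).
    # log_model="all": checkpoints are logged after every save, plus "best" at the end.
    logger = DVCLiveLogger(dir="dir", log_model=True)
    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=1)
    trainer = Trainer(logger=logger, max_epochs=2, callbacks=[checkpoint])
    # trainer.fit(model)  # model: any LightningModule
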
39 changes: 36 additions & 3 deletions tests/test_frameworks/test_lightning.py

@@ -8,8 +8,9 @@
 
 try:
     import torch
-    from pytorch_lightning import LightningModule
-    from pytorch_lightning.trainer import Trainer
+    from lightning import LightningModule
+    from lightning.pytorch import Trainer
+    from lightning.pytorch.callbacks import ModelCheckpoint
     from torch import nn
     from torch.nn import functional as F  # noqa: N812
     from torch.optim import SGD, Adam
@@ -18,7 +19,7 @@
     from dvclive import Live
     from dvclive.lightning import DVCLiveLogger
 except ImportError:
-    pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
+    pytest.skip("skipping lightning tests", allow_module_level=True)
 
 
 class XORDataset(Dataset):
@@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir):
     assert dvclive_logger.experiment._cache_images is True
 
 
+@pytest.mark.parametrize("log_model", [False, True, "all"])
+@pytest.mark.parametrize("save_top_k", [1, -1])
+def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k):
+    model = LitXOR()
+    dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model)
+    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k)
+    trainer = Trainer(
+        logger=dvclive_logger,
+        max_epochs=2,
+        log_every_n_steps=1,
+        callbacks=[checkpoint],
+    )
+    log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact")
+    trainer.fit(model)
+
+    # Check that log_artifact is called.
+    if log_model is False:
+        log_artifact.assert_not_called()
+    elif (log_model is True) and (save_top_k != -1):
+        # Called once to cache, then again to log the best artifact.
+        assert log_artifact.call_count == 2
+    else:
+        # Called once per epoch, plus two calls at the end (see above).
+        assert log_artifact.call_count == 4
+
+    # Check that the number of checkpoint files does not grow with each run.
+    num_checkpoints = len(os.listdir(tmp_dir / "model"))
+    if log_model in [True, "all"]:
+        trainer.fit(model)
+        assert len(os.listdir(tmp_dir / "model")) == num_checkpoints
+
+
 def test_lightning_steps(tmp_dir, mocker):
     model = LitXOR()
     # Handle kwargs passed to Live.
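
The call-count assertions rely on a standard pytest-mock pattern: mocker.patch.object replaces log_artifact on the live experiment with a recording mock, so the test can count invocations without DVC actually caching any files. A stripped-down illustration of just that pattern, assuming pytest-mock and dvclive are installed (test_log_artifact_is_recorded is a hypothetical example, not part of the suite):

    from dvclive import Live

    def test_log_artifact_is_recorded(tmp_dir, mocker):  # hypothetical example
        live = Live(dir="dir")
        # Replace the real method with a recording mock; nothing is cached.
        log_artifact = mocker.patch.object(live, "log_artifact")
        live.log_artifact("model.ckpt")
        log_artifact.assert_called_once_with("model.ckpt")

The new cases can be run in isolation with: pytest tests/test_frameworks/test_lightning.py -k log_model
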
