Skip to content

Commit

Permalink
lightning: drop unused checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Jul 11, 2023
1 parent 3c77b75 commit 110b9aa
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: ARG002
import inspect
from typing import Any, Dict, Optional, Union
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from lightning.fabric.utilities.logger import (
_convert_params,
Expand All @@ -9,6 +10,7 @@
)
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.loggers.utilities import _scan_checkpoints
from lightning.pytorch.utilities import rank_zero_only
from torch import is_tensor

Expand All @@ -35,7 +37,7 @@ def _should_call_next_step():


class DVCLiveLogger(Logger):
def __init__(
def __init__( # noqa: PLR0913
self,
run_name: Optional[str] = "dvclive_run",
prefix="",
Expand Down Expand Up @@ -65,7 +67,9 @@ def __init__(
# Force Live instantiation
self.experiment # noqa: B018
self._log_model = log_model
self._logged_model_time: Dict[str, float] = {}
self._checkpoint_callback: Optional[ModelCheckpoint] = None
self._all_checkpoint_paths: List[str] = []

@property
def name(self):
Expand Down Expand Up @@ -130,18 +134,37 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
if self._log_model == "all" or (
self._log_model is True and checkpoint_callback.save_top_k == -1
):
self.experiment.log_artifact(checkpoint_callback.dirpath)
self._save_checkpoints(checkpoint_callback)

@rank_zero_only
def finalize(self, status: str) -> None:
    """Save any pending checkpoint artifacts and close the DVCLive run.

    Args:
        status: Final trainer status reported by Lightning (part of the
            Logger contract; not used beyond it here).
    """
    checkpoint_callback = self._checkpoint_callback
    # `_checkpoint_callback` starts as None in __init__ and is only set by
    # after_save_checkpoint; guard so a run that never saved a checkpoint
    # does not crash on `.dirpath` / `.best_model_path`.
    if checkpoint_callback is not None:
        # Save model checkpoints. With log_model == "all" the checkpoints
        # were already saved incrementally in after_save_checkpoint.
        if self._log_model is True:
            self._save_checkpoints(checkpoint_callback)
        # Log best model. `best_model_path` is "" until ModelCheckpoint has
        # actually written a checkpoint, so skip the artifact in that case.
        if self._log_model in (True, "all"):
            best_model_path = checkpoint_callback.best_model_path
            if best_model_path:
                self.experiment.log_artifact(
                    best_model_path, name="best", type="model", cache=False
                )
    self.experiment.end()

def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Record checkpoints newly written by *checkpoint_callback*.

    Delegates to Lightning's module-level ``_scan_checkpoints`` helper, which
    yields ``(timestamp, path, score, tag)`` tuples for checkpoints not yet
    seen in ``self._logged_model_time``; each new path is remembered in both
    the time map and ``self._all_checkpoint_paths``.
    """
    new_checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
    for timestamp, path, _score, _tag in new_checkpoints:
        # Remember when we saw this path so the next scan skips it,
        # and keep the full list of every checkpoint ever logged.
        self._logged_model_time[path] = timestamp
        self._all_checkpoint_paths.append(path)

def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
# drop unused checkpoints
if not self._experiment._resume: # noqa: SLF001
for p in Path(checkpoint_callback.dirpath).iterdir():
if str(p) not in self._all_checkpoint_paths:
p.unlink(missing_ok=True)

# save directory
self.experiment.log_artifact(checkpoint_callback.dirpath)

0 comments on commit 110b9aa

Please sign in to comment.