Skip to content

Commit

Permalink
log_artifact: add cache option, only write to dvc.yaml if metadata exists
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum authored and dberenbaum committed Jul 10, 2023
1 parent 6ddc634 commit dfff9c4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
40 changes: 22 additions & 18 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,11 @@ def log_artifact(
path: StrPath,
type: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
desc: Optional[str] = None, # noqa: ARG002
labels: Optional[List[str]] = None, # noqa: ARG002
meta: Optional[Dict[str, Any]] = None, # noqa: ARG002
desc: Optional[str] = None,
labels: Optional[List[str]] = None,
meta: Optional[Dict[str, Any]] = None,
copy: bool = False,
cache: bool = True,
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
Expand All @@ -428,21 +429,24 @@ def log_artifact(
if copy:
path = clean_and_copy_into(path, self.artifacts_dir)

self.cache(path)

name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta") and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)
if cache:
self.cache(path)

if any((type, name, desc, labels, meta)):
name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta")
and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)

def cache(self, path):
try:
Expand Down
26 changes: 15 additions & 11 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import shutil
from pathlib import Path

import pytest

from dvclive import Live
from dvclive.serialize import load_yaml


def test_log_artifact(tmp_dir, dvc_repo):
@pytest.mark.parametrize("cache", [True, False])
def test_log_artifact(tmp_dir, dvc_repo, cache):
data = tmp_dir / "data"
data.touch()
with Live() as live:
live.log_artifact("data")
assert data.with_suffix(".dvc").exists()
live.log_artifact("data", cache=cache)
assert data.with_suffix(".dvc").exists() is cache
assert load_yaml(live.dvc_file) == {}


def test_log_artifact_on_existing_dvc_file(tmp_dir, dvc_repo):
Expand Down Expand Up @@ -78,14 +82,14 @@ def test_log_artifact_copy(tmp_dir, dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", copy=True)
live.log_artifact("model.pth", type="model", copy=True)

artifacts_dir = Path(live.artifacts_dir)
assert (artifacts_dir / "model.pth").exists()
assert (artifacts_dir / "model.pth.dvc").exists()

assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "artifacts/model.pth"}}
"artifacts": {"model": {"path": "artifacts/model.pth", "type": "model"}}
}


Expand All @@ -97,15 +101,15 @@ def test_log_artifact_copy_overwrite(tmp_dir, dvc_repo):
# testing with symlink cache to make sure that DVC protected mode
# does not prevent the overwrite
live._dvc_repo.cache.local.cache_types = ["symlink"]
live.log_artifact("model.pth", copy=True)
live.log_artifact("model.pth", type="model", copy=True)
assert (artifacts_dir / "model.pth").is_symlink()
live.log_artifact("model.pth", copy=True)
live.log_artifact("model.pth", type="model", copy=True)

assert (artifacts_dir / "model.pth").exists()
assert (artifacts_dir / "model.pth.dvc").exists()

assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "artifacts/model.pth"}}
"artifacts": {"model": {"path": "artifacts/model.pth", "type": "model"}}
}


Expand All @@ -119,14 +123,14 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo):
# testing with symlink cache to make sure that DVC protected mode
# does not prevent the overwrite
live._dvc_repo.cache.local.cache_types = ["symlink"]
live.log_artifact(model_path, copy=True)
live.log_artifact(model_path, type="model", copy=True)
assert (artifacts_dir / "weights" / "model-epoch-1.pth").is_symlink()

shutil.rmtree(model_path)
model_path.mkdir()
(tmp_dir / "weights" / "model-epoch-10.pth").write_text("Model weights")
(tmp_dir / "weights" / "best.pth").write_text("Best model weights")
live.log_artifact(model_path, copy=True)
live.log_artifact(model_path, type="model", copy=True)

assert (artifacts_dir / "weights").exists()
assert (artifacts_dir / "weights" / "best.pth").is_symlink()
Expand All @@ -135,7 +139,7 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo):
assert len(list((artifacts_dir / "weights").iterdir())) == 2

assert load_yaml(live.dvc_file) == {
"artifacts": {"weights": {"path": "artifacts/weights"}}
"artifacts": {"weights": {"path": "artifacts/weights", "type": "model"}}
}


Expand Down

0 comments on commit dfff9c4

Please sign in to comment.