diff --git a/dvc/api/__init__.py b/dvc/api/__init__.py index b3ba6bebe0..c82de3b24f 100644 --- a/dvc/api/__init__.py +++ b/dvc/api/__init__.py @@ -1,5 +1,6 @@ from dvc.fs.dvc import _DVCFileSystem as DVCFileSystem +from .artifacts import artifacts_show from .data import open # pylint: disable=redefined-builtin from .data import get_url, read from .experiments import exp_save, exp_show @@ -10,6 +11,7 @@ "all_branches", "all_commits", "all_tags", + "artifacts_show", "exp_save", "exp_show", "get_url", diff --git a/dvc/api/artifacts.py b/dvc/api/artifacts.py new file mode 100644 index 0000000000..33971bc602 --- /dev/null +++ b/dvc/api/artifacts.py @@ -0,0 +1,47 @@ +from typing import Any, Dict, Optional + +from dvc.repo import Repo + + +def artifacts_show( + name: str, + version: Optional[str] = None, + stage: Optional[str] = None, + repo: Optional[str] = None, +) -> Dict[str, str]: + """ + Return path and Git revision for an artifact in a DVC project. + + The resulting path and revision can be used in conjunction with other dvc.api + calls to open and read the artifact. + + Args: + name (str): name of the artifact to open. + version (str, optional): version of the artifact to open. Defaults to + the latest version. + stage (str, optional): name of the model registry stage. + repo: (str, optional): path or URL for the DVC repo. + + Returns: + Dictionary of the form: + { + "rev": ..., + "path": ..., + } + + Raises: + dvc.exceptions.ArtifactNotFoundError: The specified artifact was not found in + the repo. + """ + if version and stage: + raise ValueError("Artifact version and stage are mutually exclusive.") + + repo_kwargs: Dict[str, Any] = { + "subrepos": True, + "uninitialized": True, + } + with Repo.open(repo, **repo_kwargs) as _repo: + rev = _repo.artifacts.get_rev(name, version=version, stage=stage) + with _repo.switch(rev): + path = _repo.artifacts.get_path(name) + return {"rev": rev, "path": path} diff --git a/dvc/cli/parser.py b/dvc/cli/parser.py index 10e1cce540..a71b9fb946 100644 --- a/dvc/cli/parser.py +++ b/dvc/cli/parser.py @@ -7,6 +7,7 @@ from dvc import __version__ from dvc.commands import ( add, + artifacts, cache, check_ignore, checkout, @@ -89,6 +90,7 @@ experiments, check_ignore, data, + artifacts, ] diff --git a/dvc/commands/artifacts.py b/dvc/commands/artifacts.py new file mode 100644 index 0000000000..a2a8a564fd --- /dev/null +++ b/dvc/commands/artifacts.py @@ -0,0 +1,133 @@ +import argparse +import logging + +from dvc.cli import completion +from dvc.cli.command import CmdBaseNoRepo +from dvc.cli.utils import DictAction, append_doc_link, fix_subparsers +from dvc.exceptions import DvcException + +logger = logging.getLogger(__name__) + + +class CmdArtifactsGet(CmdBaseNoRepo): + def run(self): + from dvc.repo.artifacts import Artifacts + from dvc.scm import CloneError + from dvc.ui import ui + + try: + count, out = Artifacts.get( + self.args.url, + name=self.args.name, + version=self.args.rev, + stage=self.args.stage, + force=self.args.force, + config=self.args.config, + remote=self.args.remote, + remote_config=self.args.remote_config, + out=self.args.out, + ) + ui.write(f"Downloaded {count} file(s) to '{out}'") + return 0 + except CloneError: + logger.exception("failed to get '%s'", self.args.name) + return 1 + except DvcException: + logger.exception( + "failed to get '%s' from '%s'", self.args.name, self.args.url + ) + return 1 + + +def add_parser(subparsers, parent_parser): + ARTIFACTS_HELP = "DVC model registry artifact commands." + + artifacts_parser = subparsers.add_parser( + "artifacts", + parents=[parent_parser], + description=append_doc_link(ARTIFACTS_HELP, "artifacts"), + help=ARTIFACTS_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + artifacts_subparsers = artifacts_parser.add_subparsers( + dest="cmd", + help="Use `dvc artifacts CMD --help` to display command-specific help.", + ) + fix_subparsers(artifacts_subparsers) + + ARTIFACTS_GET_HELP = "Download an artifact from a DVC project." + get_parser = artifacts_subparsers.add_parser( + "get", + parents=[parent_parser], + description=append_doc_link(ARTIFACTS_GET_HELP, "artifacts/get"), + help=ARTIFACTS_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + get_parser.add_argument("url", help="Location of DVC repository to download from") + get_parser.add_argument( + "name", help="Name of artifact in the repository" + ).complete = completion.FILE + get_parser.add_argument( + "--rev", + nargs="?", + help="Artifact version", + metavar="", + ) + get_parser.add_argument( + "--stage", + nargs="?", + help="Artifact stage", + metavar="", + ) + get_parser.add_argument( + "-o", + "--out", + nargs="?", + help="Destination path to download artifact to", + metavar="", + ).complete = completion.DIR + get_parser.add_argument( + "-j", + "--jobs", + type=int, + help=( + "Number of jobs to run simultaneously. " + "The default value is 4 * cpu_count(). " + ), + metavar="", + ) + get_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Override local file or folder if exists.", + ) + get_parser.add_argument( + "--config", + type=str, + help=( + "Path to a config file that will be merged with the config " + "in the target repository." + ), + ) + get_parser.add_argument( + "--remote", + type=str, + help=( + "Remote name to set as a default in the target repository " + "(only applicable when downloading from DVC remote)." + ), + ) + get_parser.add_argument( + "--remote-config", + type=str, + nargs="*", + action=DictAction, + help=( + "Remote config options to merge with a remote's config (default or one " + "specified by '--remote') in the target repository (only applicable " + "when downloading from DVC remote)." + ), + ) + get_parser.set_defaults(func=CmdArtifactsGet) diff --git a/dvc/exceptions.py b/dvc/exceptions.py index ba3f8467e0..dbdefa3e52 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -1,6 +1,6 @@ """Exceptions raised by the dvc.""" import errno -from typing import TYPE_CHECKING, Dict, List, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set from dvc.utils import format_link @@ -336,3 +336,24 @@ def __init__(self, fs_paths): class PrettyDvcException(DvcException): def __pretty_exc__(self, **kwargs): """Print prettier exception message.""" + + +class ArtifactNotFoundError(DvcException): + """Thrown if an artifact is not found in the DVC repo. + + Args: + name (str): artifact name. + """ + + def __init__( + self, + name: str, + version: Optional[str] = None, + stage: Optional[str] = None, + ): + self.name = name + self.version = version + self.stage = stage + + desc = f" @ {stage or version}" if (stage or version) else "" + super().__init__(f"Unable to find artifact '{name}{desc}'") diff --git a/dvc/fs/__init__.py b/dvc/fs/__init__.py index dcc90b012f..3c948ad948 100644 --- a/dvc/fs/__init__.py +++ b/dvc/fs/__init__.py @@ -49,7 +49,9 @@ # pylint: enable=unused-import -def download(fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None): +def download( + fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None +) -> int: with Callback.as_tqdm_callback( desc=f"Downloading {fs.path.name(fs_path)}", unit="files", @@ -63,7 +65,8 @@ def download(fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None if not path.endswith(fs.path.flavour.sep) ] if not from_infos: - return localfs.makedirs(to, exist_ok=True) + localfs.makedirs(to, exist_ok=True) + return 0 to_infos = [ localfs.path.join(to, *fs.path.relparts(info, fs_path)) for info in from_infos @@ -82,6 +85,7 @@ def download(fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None callback=cb, batch_size=jobs, ) + return len(to_infos) def parse_external_url(url, config=None): diff --git a/dvc/repo/artifacts.py b/dvc/repo/artifacts.py index 0615486c29..f42ce4f051 100644 --- a/dvc/repo/artifacts.py +++ b/dvc/repo/artifacts.py @@ -1,42 +1,46 @@ import logging -import re +import os +import posixpath from pathlib import Path -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from dvc.annotations import Artifact from dvc.dvcfile import PROJECT_FILE -from dvc.exceptions import InvalidArgumentError -from dvc.repo import Repo -from dvc.utils import relpath +from dvc.exceptions import ( + ArtifactNotFoundError, + DvcException, + FileExistsLocallyError, + InvalidArgumentError, +) +from dvc.utils import relpath, resolve_output +from dvc.utils.objects import cached_property from dvc.utils.serialize import modify_yaml -logger = logging.getLogger(__name__) - - -# Constants are taken from GTO. -# When we make it a dependency, we can import them instead -SEPARATOR_IN_NAME = ":" -DIRNAME = r"[a-z0-9-_./]+" -NAME = r"[a-z0-9]([a-z0-9-/]*[a-z0-9])?" -NAME_RE = re.compile(f"^{NAME}$") -FULLNAME = f"((?P{DIRNAME}){SEPARATOR_IN_NAME})?(?P{NAME})" -FULLNAME_RE = re.compile(f"^{FULLNAME}$") +if TYPE_CHECKING: + from gto.tag import Tag as GTOTag + from scmrepo.git import GitTag + from dvc.repo import Repo + from dvc.scm import Git -def name_is_compatible(name: str) -> bool: - return bool(NAME_RE.search(name)) +logger = logging.getLogger(__name__) def check_name_format(name: str) -> None: - if not name_is_compatible(name): + from gto.constants import assert_name_is_valid + from gto.exceptions import ValidationError + + try: + assert_name_is_valid(name) + except ValidationError as exc: raise InvalidArgumentError( f"Can't use '{name}' as artifact name (ID)." - " You can use letters and numbers, and use '-' as separator" - " (but not at the start or end)." - ) + ) from exc def check_for_nested_dvc_repo(dvcfile: Path): + from dvc.repo import Repo + if dvcfile.is_absolute(): raise InvalidArgumentError("Use relative path to dvc.yaml.") path = dvcfile.parent @@ -48,11 +52,36 @@ def check_for_nested_dvc_repo(dvcfile: Path): path = path.parent +def _reformat_name(name: str) -> str: + from gto.constants import SEPARATOR_IN_NAME, fullname_re + + # NOTE: DVC accepts names like + # path/to/dvc.yaml:artifact_name + # but Studio/GTO tags are generated with + # path/to:artifact_name + m = fullname_re.match(name) + if m and m.group("dirname"): + group = m.group("dirname").rstrip(SEPARATOR_IN_NAME) + dirname, basename = posixpath.split(group) + if basename == PROJECT_FILE: + name = f"{dirname}{SEPARATOR_IN_NAME}{m.group('name')}" + return name + + class Artifacts: def __init__(self, repo: "Repo") -> None: self.repo = repo + @cached_property + def scm(self) -> Optional["Git"]: + from dvc.scm import Git + + if isinstance(self.repo.scm, Git): + return self.repo.scm + return None + def read(self) -> Dict[str, Dict[str, Artifact]]: + """Read artifacts from dvc.yaml.""" artifacts: Dict[str, Dict[str, Artifact]] = {} for ( dvcfile, @@ -69,6 +98,7 @@ def read(self) -> Dict[str, Dict[str, Artifact]]: return artifacts def add(self, name: str, artifact: Artifact, dvcfile: Optional[str] = None): + """Add artifact to dvc.yaml.""" with self.repo.scm_context(quiet=True): check_name_format(name) dvcyaml = Path(dvcfile or PROJECT_FILE) @@ -85,3 +115,219 @@ def add(self, name: str, artifact: Artifact, dvcfile: Optional[str] = None): self.repo.scm_context.track_file(dvcfile) return artifacts.get(name) + + def get_rev( + self, name: str, version: Optional[str] = None, stage: Optional[str] = None + ): + """Return revision containing the given artifact.""" + from gto.base import sort_versions + from gto.tag import find, parse_tag + + assert not (version and stage) + name = _reformat_name(name) + tags: List["GitTag"] = find( + name=name, version=version, stage=stage, scm=self.scm + ) + if not tags: + raise ArtifactNotFoundError(name, version=version, stage=stage) + if version or stage: + return tags[-1].target + gto_tags: List["GTOTag"] = sort_versions(parse_tag(tag) for tag in tags) + return gto_tags[0].tag.target + + def get_path(self, name: str): + """Return repo fspath for the given artifact.""" + from gto.constants import SEPARATOR_IN_NAME, fullname_re + + name = _reformat_name(name) + m = fullname_re.match(name) + if not m: + raise ArtifactNotFoundError(name) + dirname = m.group("dirname") + if dirname: + dirname = dirname.rstrip(SEPARATOR_IN_NAME) + dvcyaml = os.path.join(dirname, PROJECT_FILE) if dirname else PROJECT_FILE + artifact_name = m.group("name") + try: + artifact = self.read()[dvcyaml][artifact_name] + except KeyError as exc: + raise ArtifactNotFoundError(name) from exc + return os.path.join(dirname, artifact.path) if dirname else artifact.path + + def download( + self, + name: str, + version: Optional[str] = None, + stage: Optional[str] = None, + out: Optional[str] = None, + force: bool = False, + jobs: Optional[int] = None, + ) -> Tuple[int, str]: + """Download the specified artifact.""" + from dvc.fs import download as fs_download + + logger.debug("Trying to download artifact '%s' via DVC", name) + rev = self.get_rev(name, version=version, stage=stage) + with self.repo.switch(rev): + path = self.get_path(name) + out = resolve_output(path, out, force=force) + fs = self.repo.dvcfs + fs_path = fs.from_os_path(path) + count = fs_download( + fs, + fs_path, + os.path.abspath(out), + jobs=jobs, + ) + return count, out + + @staticmethod + def _download_studio( + repo_url: str, + name: str, + version: Optional[str] = None, + stage: Optional[str] = None, + out: Optional[str] = None, + force: bool = False, + jobs: Optional[int] = None, + studio_config: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Tuple[int, str]: + from dvc_studio_client.model_registry import get_download_uris + + from dvc.fs import Callback, HTTPFileSystem, generic, localfs + + logger.debug("Trying to download artifact '%s' via studio", name) + out = out or os.getcwd() + to_infos: List[str] = [] + from_infos: List[str] = [] + if studio_config is None: + studio_config = {} + studio_config["repo_url"] = repo_url + try: + kwargs["dvc_studio_config"] = studio_config + for path, url in get_download_uris( + repo_url, + name, + version=version, + stage=stage, + **kwargs, + ).items(): + to_info = localfs.path.join(out, path) + if localfs.exists(to_info) and not force: + hint = "\nTo override it, re-run with '--force'." + raise FileExistsLocallyError( # noqa: TRY301 + relpath(to_info), hint=hint + ) + to_infos.append(to_info) + from_infos.append(url) + except DvcException: + raise + except Exception as exc: # noqa: BLE001, pylint: disable=broad-except + raise DvcException( + f"Failed to download artifact '{name}' via Studio" + ) from exc + fs = HTTPFileSystem() + jobs = jobs or fs.jobs + with Callback.as_tqdm_callback( + desc=f"Downloading '{name}' from '{repo_url}'", + unit="files", + ) as cb: + cb.set_size(len(from_infos)) + generic.copy( + fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs + ) + + return len(to_infos), relpath(localfs.path.commonpath(to_infos)) + + @classmethod + def get( # noqa: C901, PLR0913 + cls, + url: str, + name: str, + version: Optional[str] = None, + stage: Optional[str] = None, + config: Optional[Union[str, Dict[str, Any]]] = None, + remote: Optional[str] = None, + remote_config: Optional[Union[str, Dict[str, Any]]] = None, + out: Optional[str] = None, + force: bool = False, + jobs: Optional[int] = None, + ): + from dvc.config import Config + from dvc.repo import Repo + + if version and stage: + raise InvalidArgumentError( + "Artifact version and stage are mutually exclusive." + ) + + # NOTE: We try to download the artifact up to three times + # 1. via studio with studio config loaded from environment + # 2. via studio with studio config loaded from DVC repo 'studio' + # section + environment + # 3. via DVC remote + + name = _reformat_name(name) + saved_exc: Optional[Exception] = None + try: + logger.trace("Trying studio-only config") # type: ignore[attr-defined] + return cls._download_studio( + url, + name, + version=version, + stage=stage, + out=out, + force=force, + jobs=jobs, + ) + except FileExistsLocallyError: + raise + except Exception as exc: # noqa: BLE001, pylint: disable=broad-except + saved_exc = exc + + if config and not isinstance(config, dict): + config = Config.load_file(config) + with Repo.open( + url=url, + subrepos=True, + uninitialized=True, + config=config, + remote=remote, + remote_config=remote_config, + ) as repo: + logger.trace("Trying repo [studio] config") # type: ignore[attr-defined] + studio_config = dict(repo.config.get("studio")) + try: + return cls._download_studio( + url, + name, + version=version, + stage=stage, + out=out, + force=force, + jobs=jobs, + dvc_studio_config=studio_config, + ) + except FileExistsLocallyError: + raise + except Exception as exc: # noqa: BLE001, pylint: disable=broad-except + saved_exc = exc + + try: + return repo.artifacts.download( + name, + version=version, + stage=stage, + out=out, + force=force, + jobs=jobs, + ) + except FileExistsLocallyError: + raise + except Exception as exc: # noqa: BLE001, pylint: disable=broad-except + if saved_exc: + logger.exception(str(saved_exc), exc_info=saved_exc.__cause__) + raise DvcException( + f"Failed to download artifact '{name}' via DVC remote" + ) from exc diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index 668f2b8e53..42f5619f46 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -248,7 +248,7 @@ def env2bool(var, undefined=False): return bool(re.search("1|y|yes|true", var, flags=re.I)) -def resolve_output(inp, out, force=False): +def resolve_output(inp: str, out: Optional[str], force=False) -> str: from urllib.parse import urlparse from dvc.exceptions import FileExistsLocallyError diff --git a/pyproject.toml b/pyproject.toml index 3fd0de861a..6926171f73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,13 +37,14 @@ dependencies = [ "dvc-data>=2.16.3,<2.17.0", "dvc-http>=2.29.0", "dvc-render>=0.3.1,<1", - "dvc-studio-client>=0.9.2,<1", + "dvc-studio-client>=0.13.0,<1", "dvc-task>=0.3.0,<1", "flatten_dict<1,>=0.4.1", # https://github.com/iterative/dvc/issues/9654 "flufl.lock>=5,<8", "funcy>=1.14", "grandalf<1,>=0.7", + "gto>=1.3.0,<2", "hydra-core>=1.1", "iterative-telemetry>=0.0.7", "networkx>=2.5", diff --git a/tests/func/artifacts/test_artifacts.py b/tests/func/artifacts/test_artifacts.py index cf6c422036..1e065d7c8e 100644 --- a/tests/func/artifacts/test_artifacts.py +++ b/tests/func/artifacts/test_artifacts.py @@ -5,8 +5,8 @@ import pytest from dvc.annotations import Artifact -from dvc.exceptions import InvalidArgumentError -from dvc.repo.artifacts import name_is_compatible +from dvc.exceptions import ArtifactNotFoundError, InvalidArgumentError +from dvc.repo.artifacts import check_name_format from dvc.utils.strictyaml import YAMLSyntaxError, YAMLValidationError dvcyaml = { @@ -43,7 +43,7 @@ def test_artifacts_read_subdir(tmp_dir, dvc): def test_artifacts_read_bad_name(tmp_dir, dvc, caplog): bad_name_dvcyaml = deepcopy(dvcyaml) - bad_name_dvcyaml["artifacts"]["bad_name"] = {"type": "model", "path": "bad.pkl"} + bad_name_dvcyaml["artifacts"]["Bad_name"] = {"type": "model", "path": "bad.pkl"} (tmp_dir / "dvc.yaml").dump(bad_name_dvcyaml) @@ -54,7 +54,7 @@ def test_artifacts_read_bad_name(tmp_dir, dvc, caplog): with caplog.at_level(logging.WARNING): assert tmp_dir.dvc.artifacts.read() == {"dvc.yaml": artifacts} - assert "Can't use 'bad_name' as artifact name (ID)" in caplog.text + assert "Can't use 'Bad_name' as artifact name (ID)" in caplog.text def test_artifacts_add_subdir(tmp_dir, dvc): @@ -160,7 +160,7 @@ def test_artifacts_read_fails_on_id_duplication(tmp_dir, dvc): ], ) def test_name_is_compatible(name): - assert name_is_compatible(name) + check_name_format(name) @pytest.mark.parametrize( @@ -172,7 +172,6 @@ def test_name_is_compatible(name): "###", "@@@", "a model", - "a_model", "-model", "model-", "model@1", @@ -181,4 +180,34 @@ def test_name_is_compatible(name): ], ) def test_name_is_compatible_fails(name): - assert not name_is_compatible(name) + with pytest.raises(InvalidArgumentError): + check_name_format(name) + + +def test_get_rev(tmp_dir, dvc, scm): + scm.tag("myart@v1.0.0#1", annotated=True, message="foo") + scm.tag("subdir=myart@v2.0.0#1", annotated=True, message="foo") + scm.tag("myart#dev#1", annotated=True, message="foo") + rev = scm.get_rev() + + assert dvc.artifacts.get_rev("myart") == rev + assert dvc.artifacts.get_rev("myart", version="v1.0.0") == rev + assert dvc.artifacts.get_rev("subdir:myart", version="v2.0.0") == rev + assert dvc.artifacts.get_rev("subdir/dvc.yaml:myart", version="v2.0.0") == rev + with pytest.raises(ArtifactNotFoundError): + dvc.artifacts.get_rev("myart", version="v3.0.0") + with pytest.raises(ArtifactNotFoundError): + dvc.artifacts.get_rev("myart", stage="prod") + + +def test_get_path(tmp_dir, dvc): + (tmp_dir / "dvc.yaml").dump(dvcyaml) + subdir = tmp_dir / "subdir" + subdir.mkdir() + (subdir / "dvc.yaml").dump(dvcyaml) + + assert dvc.artifacts.get_path("myart") == "myart.pkl" + assert dvc.artifacts.get_path("subdir:myart") == os.path.join("subdir", "myart.pkl") + assert dvc.artifacts.get_path("subdir/dvc.yaml:myart") == os.path.join( + "subdir", "myart.pkl" + )