From 4de4f0fd65e628098605fe698d383c52fbf272df Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Mon, 10 Jul 2023 02:19:37 +0300 Subject: [PATCH] fetch: introduce --type metrics/plots Needed for https://github.com/iterative/studio/pull/6541 but is also generally useful. --- dvc/commands/data_sync.py | 12 +++++ dvc/repo/fetch.py | 2 + dvc/repo/index.py | 66 ++++++++++++++++++++++++++++ tests/unit/command/test_data_sync.py | 5 +++ 4 files changed, 85 insertions(+) diff --git a/dvc/commands/data_sync.py b/dvc/commands/data_sync.py index 117684afd8..6c2f92489c 100644 --- a/dvc/commands/data_sync.py +++ b/dvc/commands/data_sync.py @@ -92,6 +92,7 @@ def run(self): recursive=self.args.recursive, run_cache=self.args.run_cache, max_size=self.args.max_size, + types=self.args.types, ) self.log_summary({"fetched": processed_files_count}) except DvcException: @@ -328,6 +329,17 @@ def add_parser(subparsers, _parent_parser): type=int, help="Fetch data files/directories that are each below specified size (bytes).", ) + fetch_parser.add_argument( + "--type", + dest="types", + action="append", + default=[], + help=( + "Only fetch data files/directories that are of a particular " + "type (metrics, plots)." + ), + choices=["metrics", "plots"], + ) fetch_parser.set_defaults(func=CmdDataFetch) # Status diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index 8b22bb5816..51e8634a85 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -23,6 +23,7 @@ def fetch( # noqa: C901, PLR0913 run_cache=False, revs=None, max_size=None, + types=None, ) -> int: """Download data items from a cloud and imported repositories @@ -72,6 +73,7 @@ def fetch( # noqa: C901, PLR0913 with_deps=with_deps, recursive=recursive, max_size=max_size, + types=types, ) index_keys.add(idx.data_tree.hash_info.value) indexes.append(idx.data["repo"]) diff --git a/dvc/repo/index.py b/dvc/repo/index.py index bb9a704ee7..833c4496ad 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -374,6 +374,50 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]: return dict(by_workspace) + @cached_property + def metric_keys(self) -> Dict[str, Set["DataIndexKey"]]: + from collections import defaultdict + + from .metrics.show import _collect_top_level_metrics + + by_workspace: Dict[str, Set["DataIndexKey"]] = defaultdict(set) + + by_workspace["repo"] = set() + + for out in self.outs: + if not out.metric: + continue + + workspace, key = out.index_key + by_workspace[workspace].add(key) + + for path in _collect_top_level_metrics(self.repo): + key = self.repo.fs.path.relparts(path, self.repo.root_dir) + by_workspace["repo"].add(key) + + return dict(by_workspace) + + @cached_property + def plot_keys(self) -> Dict[str, Set["DataIndexKey"]]: + from collections import defaultdict + + by_workspace: Dict[str, Set["DataIndexKey"]] = defaultdict(set) + + by_workspace["repo"] = set() + + for out in self.outs: + if not out.plot: + continue + + workspace, key = out.index_key + by_workspace[workspace].add(key) + + for path in self._plot_sources: + key = self.repo.fs.path.relparts(path, self.repo.root_dir) + by_workspace["repo"].add(key) + + return dict(by_workspace) + @cached_property def data_tree(self): from dvc_data.hashfile.tree import Tree @@ -487,12 +531,31 @@ def used_objs( used[odb].update(objs) return used + def _types_filter(self, types, out): + ws, okey = out.index_key + for typ in types: + if typ == "plots": + keys = self.plot_keys + elif typ == "metrics": + keys = self.metric_keys + else: + raise ValueError(f"unsupported type {typ}") + + for key in keys.get(ws, []): + if (len(key) >= len(okey) and key[: len(okey)] == okey) or ( + len(key) < len(okey) and okey[: len(key)] == key + ): + return True + + return False + def targets_view( self, targets: Optional["TargetType"], stage_filter: Optional[Callable[["Stage"], bool]] = None, outs_filter: Optional[Callable[["Output"], bool]] = None, max_size: Optional[int] = None, + types: Optional[List[str]] = None, **kwargs: Any, ) -> "IndexView": """Return read-only view of index for the specified targets. @@ -520,6 +583,9 @@ def _outs_filter(out): if max_size and out.meta and out.meta.size and out.meta.size >= max_size: return False + if types and not self._types_filter(types, out): + return False + if outs_filter: return outs_filter(out) diff --git a/tests/unit/command/test_data_sync.py b/tests/unit/command/test_data_sync.py index 9af90c5298..13f180c88f 100644 --- a/tests/unit/command/test_data_sync.py +++ b/tests/unit/command/test_data_sync.py @@ -20,6 +20,10 @@ def test_fetch(mocker, dvc): "--run-cache", "--max-size", "10", + "--type", + "plots", + "--type", + "metrics", ] ) assert cli_args.func == CmdDataFetch @@ -40,6 +44,7 @@ def test_fetch(mocker, dvc): recursive=True, run_cache=True, max_size=10, + types=["plots", "metrics"], )