diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py
index 417f088afe..c49efca80d 100644
--- a/dvc/dependency/repo.py
+++ b/dvc/dependency/repo.py
@@ -101,9 +101,9 @@ def download(self, to: "Output", jobs: Optional[int] = None):
             return
 
         hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
-        for src_path, dest_path, *rest in files:
+        for src_path, dest_path, maybe_info in files:
             try:
-                info = rest[0] if rest else self.fs.info(src_path)
+                info = maybe_info or self.fs.info(src_path)
                 hash_info = info["dvc_info"]["entry"].hash_info
                 dest_info = to.fs.info(dest_path)
             except (KeyError, AttributeError):
diff --git a/dvc/fs/__init__.py b/dvc/fs/__init__.py
index 4b739428c6..c89f203432 100644
--- a/dvc/fs/__init__.py
+++ b/dvc/fs/__init__.py
@@ -1,5 +1,6 @@
 import glob
-from typing import Optional, Union
+from itertools import repeat
+from typing import Optional
 from urllib.parse import urlparse
 
 from dvc.config import ConfigError as RepoConfigError
@@ -47,7 +48,7 @@ def download(
     fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
-) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
+) -> list[tuple[str, str, Optional[dict]]]:
     from dvc.scm import lfs_prefetch
 
     from .callbacks import TqdmCallback
@@ -84,7 +85,7 @@ def download(
         cb.set_size(len(from_infos))
         jobs = jobs or fs.jobs
         generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)
-        return list(zip(from_infos, to_infos))
+        return list(zip(from_infos, to_infos, repeat(None)))
 
 
 def parse_external_url(url, fs_config=None, config=None):
diff --git a/dvc/fs/dvc.py b/dvc/fs/dvc.py
index e5f7774736..6d1a77e74d 100644
--- a/dvc/fs/dvc.py
+++ b/dvc/fs/dvc.py
@@ -4,8 +4,8 @@
 import os
 import posixpath
 import threading
-from collections import deque
-from contextlib import ExitStack, suppress
+from collections import defaultdict, deque
+from contextlib import ExitStack, nullcontext, suppress
 from glob import has_magic
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
@@ -19,6 +19,8 @@
 from .data import DataFileSystem
 
 if TYPE_CHECKING:
+    from contextlib import AbstractContextManager
+
     from dvc.repo import Repo
     from dvc.types import DictStrAny, StrPath
 
@@ -498,7 +500,7 @@ def get(
             **kwargs,
         )
 
-    def _get(  # noqa: C901
+    def _get(  # noqa: C901, PLR0912
         self,
         rpath,
         lpath,
@@ -507,7 +509,7 @@ def _get(  # noqa: C901
         maxdepth=None,
         batch_size=None,
         **kwargs,
-    ) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
+    ) -> list[tuple[str, str, Optional[dict]]]:
         if (
            isinstance(rpath, list)
            or isinstance(lpath, list)
@@ -531,10 +533,13 @@ def _get(  # noqa: C901
         if self.isfile(rpath):
             with callback.branched(rpath, lpath) as child:
                 self.get_file(rpath, lpath, callback=child, **kwargs)
-            return [(rpath, lpath)]
+            return [(rpath, lpath, None)]
 
-        _files = []
+        result: list[tuple[str, str, Optional[dict]]] = []
         _dirs: list[str] = []
+        _files: dict[FileSystem, list[tuple[str, str, Optional[dict]]]]
+        _files = defaultdict(list)
+
         for root, dirs, files in self.walk(rpath, maxdepth=maxdepth, detail=True):
             if files:
                 callback.set_size((callback.size or 0) + len(files))
@@ -550,32 +555,42 @@ def _get(  # noqa: C901
             _, dvc_fs, _ = self._get_subrepo_info(key)
 
             for name, info in files.items():
+                dvc_info = info.get("dvc_info")
+                fs_info = info.get("fs_info")
+                if dvc_fs and dvc_info and not fs_info:
+                    fs = dvc_fs
+                    fs_path = dvc_info["name"]
+                else:
+                    fs = self.repo.fs
+                    fs_path = fs_info["name"]
+
                 src_path = f"{root}{self.sep}{name}"
                 dest_path = f"{dest_root}{os.path.sep}{name}"
-                _files.append((dvc_fs, src_path, dest_path, info))
+                _files[fs].append((fs_path, dest_path, dvc_info))
+                result.append((src_path, dest_path, info))
 
         os.makedirs(lpath, exist_ok=True)
         for d in _dirs:
             os.mkdir(d)
 
-        repo_fs = self.repo.fs
+        def get_file(arg: tuple[FileSystem, tuple[str, str, Optional[dict]]]):
+            fs, (src, dest, info) = arg
+            kw = kwargs
+            if isinstance(fs, DataFileSystem):
+                kw = kw | {"info": info}
+            return fs.get_file(src, dest, callback=callback, **kw)
 
-        def _get_file(arg):
-            dvc_fs, src, dest, info = arg
-            dvc_info = info.get("dvc_info")
-            fs_info = info.get("fs_info")
-            if dvc_fs and dvc_info and not fs_info:
-                dvc_path = dvc_info["name"]
-                dvc_fs.get_file(
-                    dvc_path, dest, callback=callback, info=dvc_info, **kwargs
-                )
-            else:
-                fs_path = fs_info["name"]
-                repo_fs.get_file(fs_path, dest, callback=callback, **kwargs)
-            return src, dest, info
+        if batch_size == 1:
+            ctx: AbstractContextManager = nullcontext()
+            map_fn: Callable = map
+        else:
+            ctx = ThreadPoolExecutor(max_workers=batch_size)
+            map_fn = ctx.imap_unordered
 
-        with ThreadPoolExecutor(max_workers=batch_size) as executor:
-            return list(executor.imap_unordered(_get_file, _files))
+        with ctx:
+            it = ((fs, f) for fs, files in _files.items() for f in files)
+            deque(map_fn(get_file, it), maxlen=0)
+        return result
 
     def get_file(self, rpath, lpath, **kwargs):
         dvc_info = kwargs.pop("info", {}).pop("dvc_info", None)
@@ -670,7 +685,7 @@ def _get(
         recursive: bool = False,
         batch_size: Optional[int] = None,
         **kwargs,
-    ) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
+    ) -> list[tuple[str, str, Optional[dict]]]:
         # FileSystem.get is non-recursive by default if arguments are lists
         # otherwise, it's recursive.
         recursive = not (isinstance(from_info, list) and isinstance(to_info, list))