diff --git a/dvc/api/data.py b/dvc/api/data.py index 51c5437b77..f5f8a6ed03 100644 --- a/dvc/api/data.py +++ b/dvc/api/data.py @@ -45,22 +45,13 @@ def get_url(path, repo=None, rev=None, remote=None): with Repo.open( repo, rev=rev, subrepos=True, uninitialized=True, **repo_kwargs ) as _repo: - with _wrap_exceptions(_repo, path): - fs_path = _repo.dvcfs.from_os_path(path) - fs = _repo.dvcfs.fs - # pylint: disable-next=protected-access - key = fs._get_key_from_relative(fs_path) - # pylint: disable-next=protected-access - subrepo, _, subkey = fs._get_subrepo_info(key) - index = subrepo.index.data["repo"] - with reraise(KeyError, OutputNotFoundError(path, repo)): - entry = index[subkey] - with reraise( - (StorageKeyError, ValueError), - NoRemoteError(f"no remote specified in {_repo}"), - ): - remote_fs, remote_path = index.storage_map.get_remote(entry) - return remote_fs.unstrip_protocol(remote_path) + index, entry = _repo.get_data_index_entry(path) + with reraise( + (StorageKeyError, ValueError), + NoRemoteError(f"no remote specified in {_repo}"), + ): + remote_fs, remote_path = index.storage_map.get_remote(entry) + return remote_fs.unstrip_protocol(remote_path) class _OpenContextManager(GCM): diff --git a/dvc/fs/dvc.py b/dvc/fs/dvc.py index 8117b019f2..4ba009928c 100644 --- a/dvc/fs/dvc.py +++ b/dvc/fs/dvc.py @@ -412,7 +412,7 @@ def _prepare_credentials(self, **config) -> Dict[str, Any]: @functools.cached_property # pylint: disable-next=invalid-overridden-method - def fs(self) -> "DVCFileSystem": + def fs(self) -> "_DVCFileSystem": return _DVCFileSystem(**self.fs_args) def isdvc(self, path, **kwargs) -> bool: diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 544fd1f696..66de6e52f2 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -29,7 +29,7 @@ from dvc.stage import Stage from dvc.types import DictStrAny from dvc_data.hashfile.state import StateBase - from dvc_data.index import DataIndex + from dvc_data.index import DataIndex, DataIndexEntry from .experiments import Experiments from .index import Index @@ -285,7 +285,7 @@ def check_graph( new.check_graph() @staticmethod - def open(url, *args, **kwargs): # noqa: A003 + def open(url: Optional[str], *args, **kwargs) -> "Repo": # noqa: A003 from .open_repo import open_repo return open_repo(url, *args, **kwargs) @@ -375,6 +375,28 @@ def drop_data_index(self) -> None: self.data_index.commit() self._reset() + def get_data_index_entry( + self, + path: str, + workspace: str = "repo", + ) -> Tuple["DataIndex", "DataIndexEntry"]: + if self.subrepos: + fs_path = self.dvcfs.from_os_path(path) + fs = self.dvcfs.fs + # pylint: disable-next=protected-access + key = fs._get_key_from_relative(fs_path) + # pylint: disable-next=protected-access + subrepo, _, key = fs._get_subrepo_info(key) + index = subrepo.index.data[workspace] + else: + index = self.index.data[workspace] + key = self.fs.path.relparts(path, self.root_dir) + + try: + return index, index[key] + except KeyError as exc: + raise OutputNotFoundError(path, self) from exc + def __repr__(self): return f"{self.__class__.__name__}: '{self.root_dir}'" diff --git a/dvc/repo/open_repo.py b/dvc/repo/open_repo.py index 8a09011661..1b5e0bbeb8 100644 --- a/dvc/repo/open_repo.py +++ b/dvc/repo/open_repo.py @@ -2,8 +2,7 @@ import os import tempfile import threading -from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterator, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple from funcy import retry, wrap_with @@ -18,16 +17,14 @@ logger = logging.getLogger(__name__) -@contextmanager @map_scm_exception() def _external_repo( url, rev: Optional[str] = None, - for_write: bool = False, **kwargs, -) -> Iterator["Repo"]: +) -> "Repo": logger.debug("Creating external repo %s@%s", url, rev) - path = _cached_clone(url, rev, for_write=for_write) + path = _cached_clone(url, rev) # Local HEAD points to the tip of whatever branch we first cloned from # (which may not be the default branch), use origin/HEAD here to get # the tip of the default branch @@ -38,11 +35,6 @@ def _external_repo( config.update(kwargs.pop("config", None) or {}) main_root = "/" - if for_write: - # we already checked out needed revision - rev = None - main_root = path - repo_kwargs = dict( root_dir=path, url=url, @@ -52,14 +44,7 @@ def _external_repo( **kwargs, ) - repo = Repo(**repo_kwargs) - - try: - yield repo - finally: - repo.close() - if for_write: - _remove(path) + return Repo(**repo_kwargs) def open_repo(url, *args, **kwargs): @@ -140,20 +125,19 @@ def _get_remote_config(url): repo.close() -def _cached_clone(url, rev, for_write=False): +def _cached_clone(url, rev): """Clone an external git repo to a temporary directory. Returns the path to a local temporary directory with the specified - revision checked out. If for_write is set prevents reusing this dir via - cache. + revision checked out. """ from shutil import copytree # even if we have already cloned this repo, we may need to # fetch/fast-forward to get specified rev - clone_path, shallow = _clone_default_branch(url, rev, for_write=for_write) + clone_path, shallow = _clone_default_branch(url, rev) - if not for_write and (url) in CLONES: + if url in CLONES: return CLONES[url][0] # Copy to a new dir to keep the clone clean @@ -161,16 +145,12 @@ def _cached_clone(url, rev, for_write=False): logger.debug("erepo: making a copy of %s clone", url) copytree(clone_path, repo_path) - # Check out the specified revision - if for_write: - _git_checkout(repo_path, rev) - else: - CLONES[url] = (repo_path, shallow) + CLONES[url] = (repo_path, shallow) return repo_path @wrap_with(threading.Lock()) -def _clone_default_branch(url, rev, for_write=False): # noqa: C901, PLR0912 +def _clone_default_branch(url, rev): # noqa: C901, PLR0912 """Get or create a clean clone of the url. The cloned is reactualized with git pull unless rev is a known sha. @@ -204,7 +184,7 @@ def _clone_default_branch(url, rev, for_write=False): # noqa: C901, PLR0912 logger.debug("erepo: git clone '%s' to a temporary dir", url) clone_path = tempfile.mkdtemp("dvc-clone") - if not for_write and rev and not Git.is_sha(rev): + if rev and not Git.is_sha(rev): # If rev is a tag or branch name try shallow clone first try: @@ -249,17 +229,6 @@ def _merge_upstream(git: "Git"): pass -def _git_checkout(repo_path, rev): - from dvc.scm import Git - - logger.debug("erepo: git checkout %s@%s", repo_path, rev) - git = Git(repo_path) - try: - git.checkout(rev) - finally: - git.close() - - def _remove(path): from dvc.utils.fs import remove