Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

repo: add helper method for doing data index lookup in subrepos #9708

Merged
merged 3 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions dvc/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,13 @@ def get_url(path, repo=None, rev=None, remote=None):
with Repo.open(
repo, rev=rev, subrepos=True, uninitialized=True, **repo_kwargs
) as _repo:
with _wrap_exceptions(_repo, path):
fs_path = _repo.dvcfs.from_os_path(path)
fs = _repo.dvcfs.fs
# pylint: disable-next=protected-access
key = fs._get_key_from_relative(fs_path)
# pylint: disable-next=protected-access
subrepo, _, subkey = fs._get_subrepo_info(key)
index = subrepo.index.data["repo"]
with reraise(KeyError, OutputNotFoundError(path, repo)):
entry = index[subkey]
with reraise(
(StorageKeyError, ValueError),
NoRemoteError(f"no remote specified in {_repo}"),
):
remote_fs, remote_path = index.storage_map.get_remote(entry)
return remote_fs.unstrip_protocol(remote_path)
index, entry = _repo.get_data_index_entry(path)
with reraise(
(StorageKeyError, ValueError),
NoRemoteError(f"no remote specified in {_repo}"),
):
remote_fs, remote_path = index.storage_map.get_remote(entry)
return remote_fs.unstrip_protocol(remote_path)


class _OpenContextManager(GCM):
Expand Down
2 changes: 1 addition & 1 deletion dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _prepare_credentials(self, **config) -> Dict[str, Any]:

@functools.cached_property
# pylint: disable-next=invalid-overridden-method
def fs(self) -> "DVCFileSystem":
def fs(self) -> "_DVCFileSystem":
return _DVCFileSystem(**self.fs_args)

def isdvc(self, path, **kwargs) -> bool:
Expand Down
26 changes: 24 additions & 2 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dvc.stage import Stage
from dvc.types import DictStrAny
from dvc_data.hashfile.state import StateBase
from dvc_data.index import DataIndex
from dvc_data.index import DataIndex, DataIndexEntry

from .experiments import Experiments
from .index import Index
Expand Down Expand Up @@ -285,7 +285,7 @@ def check_graph(
new.check_graph()

@staticmethod
def open(url, *args, **kwargs): # noqa: A003
def open(url: Optional[str], *args, **kwargs) -> "Repo": # noqa: A003
from .open_repo import open_repo

return open_repo(url, *args, **kwargs)
Expand Down Expand Up @@ -375,6 +375,28 @@ def drop_data_index(self) -> None:
self.data_index.commit()
self._reset()

def get_data_index_entry(
self,
path: str,
workspace: str = "repo",
) -> Tuple["DataIndex", "DataIndexEntry"]:
if self.subrepos:
fs_path = self.dvcfs.from_os_path(path)
fs = self.dvcfs.fs
# pylint: disable-next=protected-access
key = fs._get_key_from_relative(fs_path)
# pylint: disable-next=protected-access
subrepo, _, key = fs._get_subrepo_info(key)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This subrepo lookup still feels like it belongs in Repo somewhere and not dvcfs, but at least this way the internal lookups are kept within the Repo level helper method and we don't need to touch the dvcfs internals in dvc.api methods

index = subrepo.index.data[workspace]
else:
index = self.index.data[workspace]
key = self.fs.path.relparts(path, self.root_dir)

try:
return index, index[key]
except KeyError as exc:
raise OutputNotFoundError(path, self) from exc

def __repr__(self):
return f"{self.__class__.__name__}: '{self.root_dir}'"

Expand Down
53 changes: 11 additions & 42 deletions dvc/repo/open_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import os
import tempfile
import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterator, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from funcy import retry, wrap_with

Expand All @@ -18,16 +17,14 @@
logger = logging.getLogger(__name__)


@contextmanager
@map_scm_exception()
def _external_repo(
url,
rev: Optional[str] = None,
for_write: bool = False,
**kwargs,
) -> Iterator["Repo"]:
) -> "Repo":
logger.debug("Creating external repo %s@%s", url, rev)
path = _cached_clone(url, rev, for_write=for_write)
path = _cached_clone(url, rev)
# Local HEAD points to the tip of whatever branch we first cloned from
# (which may not be the default branch), use origin/HEAD here to get
# the tip of the default branch
Expand All @@ -38,11 +35,6 @@ def _external_repo(
config.update(kwargs.pop("config", None) or {})

main_root = "/"
if for_write:
# we already checked out needed revision
rev = None
main_root = path

repo_kwargs = dict(
root_dir=path,
url=url,
Expand All @@ -52,14 +44,7 @@ def _external_repo(
**kwargs,
)

repo = Repo(**repo_kwargs)

try:
yield repo
finally:
repo.close()
if for_write:
_remove(path)
return Repo(**repo_kwargs)


def open_repo(url, *args, **kwargs):
Expand Down Expand Up @@ -140,37 +125,32 @@ def _get_remote_config(url):
repo.close()


def _cached_clone(url, rev, for_write=False):
def _cached_clone(url, rev):
"""Clone an external git repo to a temporary directory.

Returns the path to a local temporary directory with the specified
revision checked out. If for_write is set prevents reusing this dir via
cache.
revision checked out.
"""
from shutil import copytree

# even if we have already cloned this repo, we may need to
# fetch/fast-forward to get specified rev
clone_path, shallow = _clone_default_branch(url, rev, for_write=for_write)
clone_path, shallow = _clone_default_branch(url, rev)

if not for_write and (url) in CLONES:
if url in CLONES:
return CLONES[url][0]

# Copy to a new dir to keep the clone clean
repo_path = tempfile.mkdtemp("dvc-erepo")
logger.debug("erepo: making a copy of %s clone", url)
copytree(clone_path, repo_path)

# Check out the specified revision
if for_write:
_git_checkout(repo_path, rev)
else:
CLONES[url] = (repo_path, shallow)
CLONES[url] = (repo_path, shallow)
return repo_path


@wrap_with(threading.Lock())
def _clone_default_branch(url, rev, for_write=False): # noqa: C901, PLR0912
def _clone_default_branch(url, rev): # noqa: C901, PLR0912
"""Get or create a clean clone of the url.

The cloned is reactualized with git pull unless rev is a known sha.
Expand Down Expand Up @@ -204,7 +184,7 @@ def _clone_default_branch(url, rev, for_write=False): # noqa: C901, PLR0912

logger.debug("erepo: git clone '%s' to a temporary dir", url)
clone_path = tempfile.mkdtemp("dvc-clone")
if not for_write and rev and not Git.is_sha(rev):
if rev and not Git.is_sha(rev):
# If rev is a tag or branch name try shallow clone first

try:
Expand Down Expand Up @@ -249,17 +229,6 @@ def _merge_upstream(git: "Git"):
pass


def _git_checkout(repo_path, rev):
from dvc.scm import Git

logger.debug("erepo: git checkout %s@%s", repo_path, rev)
git = Git(repo_path)
try:
git.checkout(rev)
finally:
git.close()


def _remove(path):
from dvc.utils.fs import remove

Expand Down
Loading