diff --git a/src/scmrepo/git/backend/base.py b/src/scmrepo/git/backend/base.py index e0188590..7359267a 100644 --- a/src/scmrepo/git/backend/base.py +++ b/src/scmrepo/git/backend/base.py @@ -263,6 +263,9 @@ def fetch_refspecs( returns True the local ref will be overwritten. Callback will be of the form: on_diverged(local_refname, remote_sha) + + Returns: + Mapping of local_refname to sync status. """ @abstractmethod diff --git a/src/scmrepo/git/backend/pygit2/__init__.py b/src/scmrepo/git/backend/pygit2/__init__.py index da3a1e06..42fad5dc 100644 --- a/src/scmrepo/git/backend/pygit2/__init__.py +++ b/src/scmrepo/git/backend/pygit2/__init__.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: - from pygit2 import Signature + from pygit2 import Oid, Signature from pygit2.remote import Remote # type: ignore from pygit2.repository import Repository @@ -551,7 +551,8 @@ def _merge_remote_branch( raise SCMError("Unknown merge analysis result") @contextmanager - def get_remote(self, url: str) -> Generator["Remote", None, None]: + def _get_remote(self, url: str) -> Generator["Remote", None, None]: + """Return a pygit2.Remote suitable for the specified Git URL or remote name.""" try: remote = self.repo.remotes[url] url = remote.url @@ -577,57 +578,84 @@ def fetch_refspecs( progress: Callable[["GitProgressEvent"], None] = None, **kwargs, ) -> Mapping[str, SyncStatus]: + import fnmatch + from pygit2 import GitError from .callbacks import RemoteCallbacks - if isinstance(refspecs, str): - refspecs = [refspecs] + refspecs = self._refspecs_list(refspecs, force=force) - with self.get_remote(url) as remote: - fetch_refspecs: List[str] = [] - for refspec in refspecs: - if ":" in refspec: - lh, rh = refspec.split(":") - else: - lh = rh = refspec - if not rh.startswith("refs/"): - rh = f"refs/heads/{rh}" - if not lh.startswith("refs/"): - lh = f"refs/heads/{lh}" - rh = rh[len("refs/") :] - refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}" - fetch_refspecs.append(refspec) - - logger.debug("fetch_refspecs: %s", fetch_refspecs) + # libgit2 rejects diverged refs but does not have a callback to notify + # when a ref was rejected so we have to determine whether no callback + # means up to date or rejected + def _default_status( + src: str, dst: str, remote_refs: Dict[str, "Oid"] + ) -> SyncStatus: + try: + if remote_refs[src] != self.repo.references[dst].target: + return SyncStatus.DIVERGED + except KeyError: + # remote_refs lookup is skipped when force is set, refs cannot + # be diverged on force + pass + return SyncStatus.UP_TO_DATE + + with self._get_remote(url) as remote: with reraise( GitError, SCMError(f"Git failed to fetch ref from '{url}'"), ): with RemoteCallbacks(progress=progress) as cb: + remote_refs: Dict[str, "Oid"] = ( + { + head["name"]: head["oid"] + for head in remote.ls_remotes(callbacks=cb) + } + if not force + else {} + ) remote.fetch( - refspecs=fetch_refspecs, + refspecs=refspecs, callbacks=cb, + message="fetch", ) result: Dict[str, "SyncStatus"] = {} - for refspec in fetch_refspecs: - _, rh = refspec.split(":") - if not rh.endswith("*"): - refname = rh.split("/", 3)[-1] - refname = f"refs/{refname}" - result[refname] = self._merge_remote_branch( - rh, refname, force, on_diverged - ) - continue - rh = rh.rstrip("*").rstrip("/") + "/" - for branch in self.iter_refs(base=rh): - refname = f"refs/{branch[len(rh):]}" - result[refname] = self._merge_remote_branch( - branch, refname, force, on_diverged - ) + for refspec in refspecs: + lh, rh = refspec.split(":") + if lh.endswith("*"): + assert rh.endswith("*") + lh_prefix = lh[:-1] + rh_prefix = rh[:-1] + for refname in remote_refs: + if fnmatch.fnmatch(refname, lh): + src = refname + dst = f"{rh_prefix}{refname[len(lh_prefix):]}" + result[dst] = cb.result.get( + src, _default_status(src, dst, remote_refs) + ) + else: + result[rh] = cb.result.get(lh, _default_status(lh, rh, remote_refs)) + return result + @staticmethod + def _refspecs_list( + refspecs: Union[str, Iterable[str]], + force: bool = False, + ) -> List[str]: + if isinstance(refspecs, str): + if force and not refspecs.startswith("+"): + refspecs = f"+{refspecs}" + return [refspecs] + if force: + return [ + (refspec if refspec.startswith("+") else f"+{refspec}") + for refspec in refspecs + ] + return list(refspecs) + def _stash_iter(self, ref: str): raise NotImplementedError diff --git a/src/scmrepo/git/backend/pygit2/callbacks.py b/src/scmrepo/git/backend/pygit2/callbacks.py index 5ddf5ed1..9107ad8b 100644 --- a/src/scmrepo/git/backend/pygit2/callbacks.py +++ b/src/scmrepo/git/backend/pygit2/callbacks.py @@ -1,13 +1,15 @@ from contextlib import AbstractContextManager from types import TracebackType -from typing import TYPE_CHECKING, Callable, Optional, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union from pygit2 import RemoteCallbacks as _RemoteCallbacks +from scmrepo.git.backend.base import SyncStatus from scmrepo.git.credentials import Credential, CredentialNotFoundError from scmrepo.progress import GitProgressReporter if TYPE_CHECKING: + from pygit2 import Oid from pygit2.credentials import Keypair, Username, UserPass from scmrepo.progress import GitProgressEvent @@ -27,6 +29,7 @@ def __init__( self.progress = GitProgressReporter(progress) if progress else None self._store_credentials: Optional["Credential"] = None self._tried_credentials = False + self.result: Dict[str, SyncStatus] = {} def __exit__( self, @@ -66,3 +69,9 @@ def credentials( def _approve_credentials(self): if self._store_credentials: self._store_credentials.approve() + + def update_tips(self, refname: str, old: "Oid", new: "Oid"): + if old == new: + self.result[refname] = SyncStatus.UP_TO_DATE + else: + self.result[refname] = SyncStatus.SUCCESS diff --git a/tests/test_pygit2.py b/tests/test_pygit2.py index 33393993..b9f096a6 100644 --- a/tests/test_pygit2.py +++ b/tests/test_pygit2.py @@ -74,7 +74,7 @@ def test_pygit_stash_apply_conflicts( def test_pygit_ssh_error(tmp_dir: TmpDir, scm: Git, url): backend = Pygit2Backend(tmp_dir) with pytest.raises(NotImplementedError): - with backend.get_remote(url): + with backend._get_remote(url): # pylint: disable=protected-access pass