Skip to content

Commit

Permalink
replace deprecated GIT_* constants with pygit2.enums (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored May 20, 2024
1 parent 06504c9 commit c12c7bf
Showing 1 changed file with 52 additions and 75 deletions.
127 changes: 52 additions & 75 deletions src/scmrepo/git/backend/pygit2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
if TYPE_CHECKING:
from pygit2 import Commit, Oid, Signature
from pygit2.config import Config as _Pygit2Config
from pygit2.enums import CheckoutStrategy
from pygit2.remotes import Remote
from pygit2.repository import Repository

Expand Down Expand Up @@ -246,17 +247,15 @@ def _get_signature(self, name: str) -> "Signature":
)

@staticmethod
def _get_checkout_strategy(strategy: Optional[int] = None):
from pygit2 import (
GIT_CHECKOUT_RECREATE_MISSING,
GIT_CHECKOUT_SAFE,
GIT_CHECKOUT_SKIP_LOCKED_DIRECTORIES,
)
def _get_checkout_strategy(
strategy: Optional["CheckoutStrategy"] = None,
) -> "CheckoutStrategy":
from pygit2.enums import CheckoutStrategy

if strategy is None:
strategy = GIT_CHECKOUT_SAFE | GIT_CHECKOUT_RECREATE_MISSING
strategy = CheckoutStrategy.SAFE | CheckoutStrategy.RECREATE_MISSING
if os.name == "nt":
strategy |= GIT_CHECKOUT_SKIP_LOCKED_DIRECTORIES
strategy |= CheckoutStrategy.SKIP_LOCKED_DIRECTORIES
return strategy

# Workaround to force git_backend_odb_pack to release open file handles
Expand Down Expand Up @@ -343,9 +342,12 @@ def checkout(
force: bool = False,
**kwargs,
):
from pygit2 import GIT_CHECKOUT_FORCE, GitError
from pygit2 import GitError
from pygit2.enums import CheckoutStrategy

strategy = self._get_checkout_strategy(GIT_CHECKOUT_FORCE if force else None)
strategy = self._get_checkout_strategy(
CheckoutStrategy.FORCE if force else None
)

with self.release_odb_handles():
if create_new:
Expand Down Expand Up @@ -613,7 +615,7 @@ def _merge_remote_branch(
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
) -> SyncStatus:
import pygit2
from pygit2.enums import MergeAnalysis

rh_rev = self.resolve_rev(rh)

Expand All @@ -627,16 +629,16 @@ def _merge_remote_branch(
self.set_ref(lh, rh_rev)
return SyncStatus.SUCCESS

if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
if merge_result & MergeAnalysis.UP_TO_DATE:
return SyncStatus.UP_TO_DATE
if merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
if merge_result & MergeAnalysis.FASTFORWARD:
self.set_ref(lh, rh_rev)
return SyncStatus.SUCCESS
if merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
if merge_result & MergeAnalysis.NORMAL:
if on_diverged and on_diverged(lh, rh_rev):
return SyncStatus.SUCCESS
return SyncStatus.DIVERGED
logger.debug("Unexpected merge result: %s", pygit2.GIT_MERGE_ANALYSIS_NORMAL)
logger.debug("Unexpected merge result: %s", MergeAnalysis.NORMAL)
raise SCMError("Unknown merge analysis result")

@contextmanager
Expand Down Expand Up @@ -779,7 +781,8 @@ def _stash_apply(
skip_conflicts: bool = False,
**kwargs,
):
from pygit2 import GIT_CHECKOUT_ALLOW_CONFLICTS, GitError
from pygit2 import GitError
from pygit2.enums import CheckoutStrategy

from scmrepo.git import Stash

Expand All @@ -788,7 +791,7 @@ def _apply(index):
self.repo.index.read(False)
strategy = self._get_checkout_strategy()
if skip_conflicts:
strategy |= GIT_CHECKOUT_ALLOW_CONFLICTS
strategy |= CheckoutStrategy.ALLOW_CONFLICTS
self.repo.stash_apply(
index, strategy=strategy, reinstate_index=reinstate_index
)
Expand Down Expand Up @@ -834,7 +837,8 @@ def diff(self, rev_a: str, rev_b: str, binary=False) -> str:
raise NotImplementedError

def reset(self, hard: bool = False, paths: Optional[Iterable[str]] = None):
from pygit2 import GIT_RESET_HARD, GIT_RESET_MIXED, IndexEntry
from pygit2 import IndexEntry
from pygit2.enums import ResetMode

self.repo.index.read(False)
if paths is not None:
Expand All @@ -847,9 +851,9 @@ def reset(self, hard: bool = False, paths: Optional[Iterable[str]] = None):
self.repo.index.add(IndexEntry(rel, obj.id, obj.filemode))
self.repo.index.write()
elif hard:
self.repo.reset(self.repo.head.target, GIT_RESET_HARD)
self.repo.reset(self.repo.head.target, ResetMode.HARD)
else:
self.repo.reset(self.repo.head.target, GIT_RESET_MIXED)
self.repo.reset(self.repo.head.target, ResetMode.MIXED)

def checkout_index(
self,
Expand All @@ -858,22 +862,17 @@ def checkout_index(
ours: bool = False,
theirs: bool = False,
):
from pygit2 import (
GIT_CHECKOUT_ALLOW_CONFLICTS,
GIT_CHECKOUT_FORCE,
GIT_CHECKOUT_RECREATE_MISSING,
GIT_CHECKOUT_SAFE,
)
from pygit2.enums import CheckoutStrategy

assert not (ours and theirs)
strategy = GIT_CHECKOUT_RECREATE_MISSING
strategy = CheckoutStrategy.RECREATE_MISSING
if force or ours or theirs:
strategy |= GIT_CHECKOUT_FORCE
strategy |= CheckoutStrategy.FORCE
else:
strategy |= GIT_CHECKOUT_SAFE
strategy |= CheckoutStrategy.SAFE

if ours or theirs:
strategy |= GIT_CHECKOUT_ALLOW_CONFLICTS
strategy |= CheckoutStrategy.ALLOW_CONFLICTS
strategy = self._get_checkout_strategy(strategy)

index = self.repo.index
Expand Down Expand Up @@ -910,18 +909,7 @@ def checkout_index(
def status(
self, ignored: bool = False, untracked_files: str = "all"
) -> tuple[Mapping[str, Iterable[str]], Iterable[str], Iterable[str]]:
from pygit2 import (
GIT_STATUS_IGNORED,
GIT_STATUS_INDEX_DELETED,
GIT_STATUS_INDEX_MODIFIED,
GIT_STATUS_INDEX_NEW,
GIT_STATUS_WT_DELETED,
GIT_STATUS_WT_MODIFIED,
GIT_STATUS_WT_NEW,
GIT_STATUS_WT_RENAMED,
GIT_STATUS_WT_TYPECHANGE,
GIT_STATUS_WT_UNREADABLE,
)
from pygit2.enums import FileStatus

staged: Mapping[str, list[str]] = {
"add": [],
Expand All @@ -932,19 +920,19 @@ def status(
untracked: list[str] = []

states = {
GIT_STATUS_WT_NEW: untracked,
GIT_STATUS_WT_MODIFIED: unstaged,
GIT_STATUS_WT_TYPECHANGE: staged["modify"],
GIT_STATUS_WT_DELETED: staged["modify"],
GIT_STATUS_WT_RENAMED: staged["modify"],
GIT_STATUS_INDEX_NEW: staged["add"],
GIT_STATUS_INDEX_MODIFIED: staged["modify"],
GIT_STATUS_INDEX_DELETED: staged["delete"],
GIT_STATUS_WT_UNREADABLE: untracked,
FileStatus.WT_NEW: untracked,
FileStatus.WT_MODIFIED: unstaged,
FileStatus.WT_TYPECHANGE: staged["modify"],
FileStatus.WT_DELETED: staged["modify"],
FileStatus.WT_RENAMED: staged["modify"],
FileStatus.INDEX_NEW: staged["add"],
FileStatus.INDEX_MODIFIED: staged["modify"],
FileStatus.INDEX_DELETED: staged["delete"],
FileStatus.WT_UNREADABLE: untracked,
}

if untracked_files != "no" and ignored:
states[GIT_STATUS_IGNORED] = untracked
states[FileStatus.IGNORED] = untracked

for file, state in self.repo.status(
untracked_files=untracked_files, ignored=ignored
Expand All @@ -970,15 +958,8 @@ def merge( # noqa: C901
msg: Optional[str] = None,
squash: bool = False,
) -> Optional[str]:
from pygit2 import (
GIT_MERGE_ANALYSIS_FASTFORWARD,
GIT_MERGE_ANALYSIS_NONE,
GIT_MERGE_ANALYSIS_UNBORN,
GIT_MERGE_ANALYSIS_UP_TO_DATE,
GIT_MERGE_PREFERENCE_FASTFORWARD_ONLY,
GIT_MERGE_PREFERENCE_NO_FASTFORWARD,
GitError,
)
from pygit2 import GitError
from pygit2.enums import MergeAnalysis, MergePreference

if commit and squash:
raise SCMError("Cannot merge with 'squash' and 'commit'")
Expand All @@ -991,9 +972,9 @@ def merge( # noqa: C901
except GitError as exc:
raise SCMError("Merge analysis failed") from exc

if analysis == GIT_MERGE_ANALYSIS_NONE:
if analysis == MergeAnalysis.NONE:
raise SCMError(f"'{rev}' cannot be merged into HEAD")
if analysis & GIT_MERGE_ANALYSIS_UP_TO_DATE:
if analysis & MergeAnalysis.UP_TO_DATE:
return None

try:
Expand All @@ -1006,15 +987,15 @@ def merge( # noqa: C901
raise MergeConflictError("Merge contained conflicts")

try:
if not (squash or ff_pref & GIT_MERGE_PREFERENCE_NO_FASTFORWARD):
if analysis & GIT_MERGE_ANALYSIS_FASTFORWARD:
if not (squash or ff_pref & MergePreference.NO_FASTFORWARD):
if analysis & MergeAnalysis.FASTFORWARD:
return self._merge_ff(rev, obj)

if analysis & GIT_MERGE_ANALYSIS_UNBORN:
if analysis & MergeAnalysis.UNBORN:
self.repo.set_head(obj.id)
return str(obj.id)

if ff_pref & GIT_MERGE_PREFERENCE_FASTFORWARD_ONLY:
if ff_pref & MergePreference.FASTFORWARD_ONLY:
raise SCMError(f"Cannot fast-forward HEAD to '{rev}'")

if commit:
Expand Down Expand Up @@ -1105,19 +1086,15 @@ def check_attr(
attr: str,
source: Optional[str] = None,
) -> Optional[Union[bool, str]]:
from pygit2 import (
GIT_ATTR_CHECK_FILE_THEN_INDEX,
GIT_ATTR_CHECK_INCLUDE_COMMIT,
GIT_ATTR_CHECK_INDEX_ONLY,
GitError,
)
from pygit2 import GitError
from pygit2.enums import AttrCheck

commit: Optional["Commit"] = None
flags = GIT_ATTR_CHECK_FILE_THEN_INDEX
flags = AttrCheck.FILE_THEN_INDEX
if source:
try:
commit, _ref = self._resolve_refish(source)
flags = GIT_ATTR_CHECK_INDEX_ONLY | GIT_ATTR_CHECK_INCLUDE_COMMIT
flags = AttrCheck.INDEX_ONLY | AttrCheck.INCLUDE_COMMIT
except (KeyError, GitError) as exc:
raise SCMError(f"Invalid commit '{source}'") from exc
try:
Expand Down

0 comments on commit c12c7bf

Please sign in to comment.