diff --git a/src/scmrepo/git/lfs/client.py b/src/scmrepo/git/lfs/client.py index 05bed28..cf9dcfd 100644 --- a/src/scmrepo/git/lfs/client.py +++ b/src/scmrepo/git/lfs/client.py @@ -1,6 +1,9 @@ +import json import logging import os +import re import shutil +from abc import abstractmethod from collections.abc import Iterable, Iterator from contextlib import AbstractContextManager, contextmanager, suppress from tempfile import NamedTemporaryFile @@ -13,6 +16,7 @@ from fsspec.implementations.http import HTTPFileSystem from funcy import cached_property +from scmrepo.git.backend.dulwich import _get_ssh_vendor from scmrepo.git.credentials import Credential, CredentialNotFoundError from .exceptions import LFSError @@ -35,19 +39,12 @@ class LFSClient(AbstractContextManager): _SESSION_RETRIES = 5 _SESSION_BACKOFF_FACTOR = 0.1 - def __init__( - self, - url: str, - git_url: Optional[str] = None, - headers: Optional[dict[str, str]] = None, - ): + def __init__(self, url: str): """ Args: url: LFS server URL. """ self.url = url - self.git_url = git_url - self.headers: dict[str, str] = headers or {} def __exit__(self, *args, **kwargs): self.close() @@ -84,23 +81,18 @@ def loop(self): @classmethod def from_git_url(cls, git_url: str) -> "LFSClient": - if git_url.endswith(".git"): - url = f"{git_url}/info/lfs" - else: - url = f"{git_url}.git/info/lfs" - return cls(url, git_url=git_url) + if git_url.startswith(("ssh://", "git@")): + return _SSHLFSClient.from_git_url(git_url) + if git_url.startswith("https://"): + return _HTTPLFSClient.from_git_url(git_url) + raise NotImplementedError(f"Unsupported Git URL: {git_url}") def close(self): pass - def _get_auth(self) -> Optional[aiohttp.BasicAuth]: - try: - creds = Credential(url=self.git_url).fill() - if creds.username and creds.password: - return aiohttp.BasicAuth(creds.username, creds.password) - except CredentialNotFoundError: - pass - return None + @abstractmethod + def _get_auth_header(self, *, upload: bool) -> dict: + ... async def _batch_request( self, @@ -120,9 +112,10 @@ async def _batch_request( if ref: body["ref"] = [{"name": ref}] session = await self._fs.set_session() - headers = dict(self.headers) - headers["Accept"] = self.JSON_CONTENT_TYPE - headers["Content-Type"] = self.JSON_CONTENT_TYPE + headers = { + "Accept": self.JSON_CONTENT_TYPE, + "Content-Type": self.JSON_CONTENT_TYPE, + } try: async with session.post( url, @@ -134,13 +127,12 @@ async def _batch_request( except aiohttp.ClientResponseError as exc: if exc.status != 401: raise - auth = self._get_auth() - if auth is None: + auth_header = self._get_auth_header(upload=upload) + if not auth_header: raise async with session.post( url, - auth=auth, - headers=headers, + headers={**headers, **auth_header}, json=body, raise_for_status=True, ) as resp: @@ -186,6 +178,85 @@ async def _get_one(from_path: str, to_path: str, **kwargs): download = sync_wrapper(_download) +class _HTTPLFSClient(LFSClient): + def __init__(self, url: str, git_url: str): + """ + Args: + url: LFS server URL. + git_url: Git HTTP URL. + """ + super().__init__(url) + self.git_url = git_url + + @classmethod + def from_git_url(cls, git_url: str) -> "_HTTPLFSClient": + if git_url.endswith(".git"): + url = f"{git_url}/info/lfs" + else: + url = f"{git_url}.git/info/lfs" + return cls(url, git_url=git_url) + + def _get_auth_header(self, *, upload: bool) -> dict: + try: + creds = Credential(url=self.git_url).fill() + if creds.username and creds.password: + return { + aiohttp.hdrs.AUTHORIZATION: aiohttp.BasicAuth( + creds.username, creds.password + ).encode() + } + except CredentialNotFoundError: + pass + return {} + + +class _SSHLFSClient(LFSClient): + _URL_PATTERN = re.compile( + r"(?:ssh://)?git@(?P\S+?)(?::(?P\d+))?(?:[:/])(?P\S+?)\.git" + ) + + def __init__(self, url: str, host: str, port: int, path: str): + """ + Args: + url: LFS server URL. + host: Git SSH server host. + port: Git SSH server port. + path: Git project path. + """ + super().__init__(url) + self.host = host + self.port = port + self.path = path + self._ssh = _get_ssh_vendor() + + @classmethod + def from_git_url(cls, git_url: str) -> "_SSHLFSClient": + result = cls._URL_PATTERN.match(git_url) + if not result: + raise ValueError(f"Invalid Git SSH URL: {git_url}") + host, port, path = result.group("host", "port", "path") + url = f"https://{host}/{path}.git/info/lfs" + return cls(url, host, int(port or 22), path) + + def _get_auth_header(self, *, upload: bool) -> dict: + return self._git_lfs_authenticate( + self.host, self.port, f"{self.path}.git", upload=upload + ).get("header", {}) + + def _git_lfs_authenticate( + self, host: str, port: int, path: str, *, upload: bool = False + ) -> dict: + action = "upload" if upload else "download" + return json.loads( + self._ssh.run_command( + command=f"git-lfs-authenticate {path} {action}", + host=host, + port=port, + username="git", + ).read() + ) + + @contextmanager def _as_atomic(to_info: str, create_parents: bool = False) -> Iterator[str]: parent = os.path.dirname(to_info)