Skip to content

Commit

Permalink
lfs: add support for Git SSH URLs
Browse files Browse the repository at this point in the history
  • Loading branch information
sisp authored and pmrowla committed Feb 29, 2024
1 parent 4819e71 commit c44c577
Showing 1 changed file with 99 additions and 28 deletions.
127 changes: 99 additions & 28 deletions src/scmrepo/git/lfs/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
import logging
import os
import re
import shutil
from abc import abstractmethod
from collections.abc import Iterable, Iterator
from contextlib import AbstractContextManager, contextmanager, suppress
from tempfile import NamedTemporaryFile
Expand All @@ -13,6 +16,7 @@
from fsspec.implementations.http import HTTPFileSystem
from funcy import cached_property

from scmrepo.git.backend.dulwich import _get_ssh_vendor
from scmrepo.git.credentials import Credential, CredentialNotFoundError

from .exceptions import LFSError
Expand All @@ -35,19 +39,12 @@ class LFSClient(AbstractContextManager):
_SESSION_RETRIES = 5
_SESSION_BACKOFF_FACTOR = 0.1

def __init__(
self,
url: str,
git_url: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
):
def __init__(self, url: str):
"""
Args:
url: LFS server URL.
"""
self.url = url
self.git_url = git_url
self.headers: dict[str, str] = headers or {}

def __exit__(self, *args, **kwargs):
self.close()
Expand Down Expand Up @@ -84,23 +81,18 @@ def loop(self):

@classmethod
def from_git_url(cls, git_url: str) -> "LFSClient":
if git_url.endswith(".git"):
url = f"{git_url}/info/lfs"
else:
url = f"{git_url}.git/info/lfs"
return cls(url, git_url=git_url)
if git_url.startswith(("ssh://", "git@")):
return _SSHLFSClient.from_git_url(git_url)
if git_url.startswith("https://"):
return _HTTPLFSClient.from_git_url(git_url)
raise NotImplementedError(f"Unsupported Git URL: {git_url}")

def close(self):
pass

def _get_auth(self) -> Optional[aiohttp.BasicAuth]:
try:
creds = Credential(url=self.git_url).fill()
if creds.username and creds.password:
return aiohttp.BasicAuth(creds.username, creds.password)
except CredentialNotFoundError:
pass
return None
@abstractmethod
def _get_auth_header(self, *, upload: bool) -> dict:
...

async def _batch_request(
self,
Expand All @@ -120,9 +112,10 @@ async def _batch_request(
if ref:
body["ref"] = [{"name": ref}]
session = await self._fs.set_session()
headers = dict(self.headers)
headers["Accept"] = self.JSON_CONTENT_TYPE
headers["Content-Type"] = self.JSON_CONTENT_TYPE
headers = {
"Accept": self.JSON_CONTENT_TYPE,
"Content-Type": self.JSON_CONTENT_TYPE,
}
try:
async with session.post(
url,
Expand All @@ -134,13 +127,12 @@ async def _batch_request(
except aiohttp.ClientResponseError as exc:
if exc.status != 401:
raise
auth = self._get_auth()
if auth is None:
auth_header = self._get_auth_header(upload=upload)
if not auth_header:
raise
async with session.post(
url,
auth=auth,
headers=headers,
headers={**headers, **auth_header},
json=body,
raise_for_status=True,
) as resp:
Expand Down Expand Up @@ -186,6 +178,85 @@ async def _get_one(from_path: str, to_path: str, **kwargs):
download = sync_wrapper(_download)


class _HTTPLFSClient(LFSClient):
def __init__(self, url: str, git_url: str):
"""
Args:
url: LFS server URL.
git_url: Git HTTP URL.
"""
super().__init__(url)
self.git_url = git_url

@classmethod
def from_git_url(cls, git_url: str) -> "_HTTPLFSClient":
if git_url.endswith(".git"):
url = f"{git_url}/info/lfs"
else:
url = f"{git_url}.git/info/lfs"
return cls(url, git_url=git_url)

def _get_auth_header(self, *, upload: bool) -> dict:
try:
creds = Credential(url=self.git_url).fill()
if creds.username and creds.password:
return {
aiohttp.hdrs.AUTHORIZATION: aiohttp.BasicAuth(
creds.username, creds.password
).encode()
}
except CredentialNotFoundError:
pass
return {}


class _SSHLFSClient(LFSClient):
_URL_PATTERN = re.compile(
r"(?:ssh://)?git@(?P<host>\S+?)(?::(?P<port>\d+))?(?:[:/])(?P<path>\S+?)\.git"
)

def __init__(self, url: str, host: str, port: int, path: str):
"""
Args:
url: LFS server URL.
host: Git SSH server host.
port: Git SSH server port.
path: Git project path.
"""
super().__init__(url)
self.host = host
self.port = port
self.path = path
self._ssh = _get_ssh_vendor()

@classmethod
def from_git_url(cls, git_url: str) -> "_SSHLFSClient":
result = cls._URL_PATTERN.match(git_url)
if not result:
raise ValueError(f"Invalid Git SSH URL: {git_url}")
host, port, path = result.group("host", "port", "path")
url = f"https://{host}/{path}.git/info/lfs"
return cls(url, host, int(port or 22), path)

def _get_auth_header(self, *, upload: bool) -> dict:
return self._git_lfs_authenticate(
self.host, self.port, f"{self.path}.git", upload=upload
).get("header", {})

def _git_lfs_authenticate(
self, host: str, port: int, path: str, *, upload: bool = False
) -> dict:
action = "upload" if upload else "download"
return json.loads(
self._ssh.run_command(
command=f"git-lfs-authenticate {path} {action}",
host=host,
port=port,
username="git",
).read()
)


@contextmanager
def _as_atomic(to_info: str, create_parents: bool = False) -> Iterator[str]:
parent = os.path.dirname(to_info)
Expand Down

0 comments on commit c44c577

Please sign in to comment.