Skip to content

Commit

Permalink
support scp-style shorthand urls with users other than git@ (#346)
Browse files Browse the repository at this point in the history
* support scp-style shorthand urls with users other than git@

* lfs: support scp-style urls with username other than git

* do not add .git to the path

* add .git/info/lfs suffix to the lfs api url

* fix regex
  • Loading branch information
skshetry authored Mar 22, 2024
1 parent 1096899 commit 8987d1a
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 20 deletions.
11 changes: 6 additions & 5 deletions src/scmrepo/git/backend/pygit2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import stat
from collections.abc import Generator, Iterable, Iterator, Mapping
from collections.abc import Iterable, Iterator, Mapping
from contextlib import contextmanager
from io import BytesIO, StringIO, TextIOWrapper
from typing import (
Expand All @@ -25,6 +25,7 @@
from scmrepo.git.backend.base import BaseGitBackend, SyncStatus
from scmrepo.git.config import Config
from scmrepo.git.objects import GitCommit, GitObject, GitTag
from scmrepo.urls import is_scp_style_url
from scmrepo.utils import relpath

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -636,7 +637,7 @@ def _merge_remote_branch(
raise SCMError("Unknown merge analysis result")

@contextmanager
def _get_remote(self, url: str) -> Generator["Remote", None, None]:
def _get_remote(self, url: str) -> Iterator["Remote"]:
"""Return a pygit2.Remote suitable for the specified Git URL or remote name."""
try:
remote = self.repo.remotes[url]
Expand All @@ -646,11 +647,11 @@ def _get_remote(self, url: str) -> Generator["Remote", None, None]:
except KeyError as exc:
raise SCMError(f"'{url}' is not a valid Git remote or URL") from exc

if os.name == "nt" and url.startswith("file://"):
url = url[len("file://") :]
if os.name == "nt":
url = url.removeprefix("file://")
remote = self.repo.remotes.create_anonymous(url)
parsed = urlparse(remote.url)
if parsed.scheme in ("git", "git+ssh", "ssh") or remote.url.startswith("git@"):
if parsed.scheme in ("git", "git+ssh", "ssh") or is_scp_style_url(remote.url):
raise NotImplementedError
yield remote

Expand Down
46 changes: 31 additions & 15 deletions src/scmrepo/git/lfs/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import os
import re
import shutil
from abc import abstractmethod
from collections.abc import Iterable, Iterator
Expand All @@ -10,6 +9,7 @@
from tempfile import NamedTemporaryFile
from time import time
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import urlparse

import aiohttp
from aiohttp_retry import ExponentialRetry, RetryClient
Expand All @@ -20,6 +20,7 @@

from scmrepo.git.backend.dulwich import _get_ssh_vendor
from scmrepo.git.credentials import Credential, CredentialNotFoundError
from scmrepo.urls import SCP_REGEX, is_scp_style_url

from .exceptions import LFSError
from .pointer import Pointer
Expand Down Expand Up @@ -84,7 +85,7 @@ def loop(self):

@classmethod
def from_git_url(cls, git_url: str) -> "LFSClient":
if git_url.startswith(("ssh://", "git@")):
if git_url.startswith("ssh://") or is_scp_style_url(git_url):
return _SSHLFSClient.from_git_url(git_url)
if git_url.startswith(("http://", "https://")):
return _HTTPLFSClient.from_git_url(git_url)
Expand Down Expand Up @@ -213,11 +214,9 @@ def _get_auth_header(self, *, upload: bool) -> dict:


class _SSHLFSClient(LFSClient):
_URL_PATTERN = re.compile(
r"(?:ssh://)?git@(?P<host>\S+?)(?::(?P<port>\d+))?(?:[:/])(?P<path>\S+?)\.git"
)

def __init__(self, url: str, host: str, port: int, path: str):
def __init__(
self, url: str, host: str, port: int, username: Optional[str], path: str
):
"""
Args:
url: LFS server URL.
Expand All @@ -228,33 +227,50 @@ def __init__(self, url: str, host: str, port: int, path: str):
super().__init__(url)
self.host = host
self.port = port
self.username = username
self.path = path
self._ssh = _get_ssh_vendor()

@classmethod
def from_git_url(cls, git_url: str) -> "_SSHLFSClient":
result = cls._URL_PATTERN.match(git_url)
if not result:
if scp_match := SCP_REGEX.match(git_url):
# Add an ssh:// prefix and replace the ':' with a '/'.
git_url = scp_match.expand(r"ssh://\1\2/\3")

parsed = urlparse(git_url)
if parsed.scheme != "ssh" or not parsed.hostname:
raise ValueError(f"Invalid Git SSH URL: {git_url}")
host, port, path = result.group("host", "port", "path")
url = f"https://{host}/{path}.git/info/lfs"
return cls(url, host, int(port or 22), path)

host = parsed.hostname
port = parsed.port or 22
path = parsed.path.lstrip("/")
username = parsed.username

url_path = path.removesuffix(".git") + ".git/info/lfs"
url = f"https://{host}/{url_path}"
return cls(url, host, port, username, path)

def _get_auth_header(self, *, upload: bool) -> dict:
return self._git_lfs_authenticate(
self.host, self.port, f"{self.path}.git", upload=upload
self.host, self.port, self.username, self.path, upload=upload
).get("header", {})

def _git_lfs_authenticate(
self, host: str, port: int, path: str, *, upload: bool = False
self,
host: str,
port: int,
username: Optional[str],
path: str,
*,
upload: bool = False,
) -> dict:
action = "upload" if upload else "download"
return json.loads(
self._ssh.run_command(
command=f"git-lfs-authenticate {path} {action}",
host=host,
port=port,
username="git",
username=username,
).read()
)

Expand Down
21 changes: 21 additions & 0 deletions src/scmrepo/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import re

# Adapted from pip's Git VCS support:
# https://github.com/pypa/pip/blob/303fed36c1771de4063063a866776a9103972317/src/pip/_internal/vcs/git.py#L40
#
# Matches SCP (Secure copy protocol) shorthand, e.g. 'git@example.com:foo/bar.git'.
# The whole string must match (the pattern is anchored with ^ and $).
# Capture groups, in order:
#   1. optional user, including the trailing '@' (e.g. 'git@');
#   2. server: one or more characters containing neither '/' nor ':';
#   3. server-side path: starts with a word character and contains no ':'.
# The pattern is compiled with re.VERBOSE, so the whitespace and '#' comments
# inside the pattern string are ignored by the regex engine.
SCP_REGEX = re.compile(
r"""^
# Optional user, e.g. 'git@'
(\w+@)?
# Server, e.g. 'github.com'.
([^/:]+):
# The server-side path. e.g. 'user/project.git'. Must start with an
# alphanumeric character so as not to be confusable with a Windows paths
# like 'C:/foo/bar' or 'C:\foo\bar'.
(\w[^:]*)
$""",
re.VERBOSE,
)


def is_scp_style_url(url: str) -> bool:
    """Return whether *url* is an SCP-style shorthand Git URL.

    A URL qualifies when it matches ``SCP_REGEX`` in full, for example
    ``git@github.com:iterative/scmrepo.git`` or, without a user part,
    ``github.com:iterative/scmrepo.git``. URLs with an explicit scheme
    (``ssh://...``, ``https://...``) and Windows paths such as
    ``C:/foo/bar`` do not match.
    """
    # A successful match always yields a (truthy) Match object; failure
    # yields None, so the identity test is equivalent to bool(match).
    return SCP_REGEX.match(url) is not None
2 changes: 2 additions & 0 deletions tests/test_pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def test_pygit_stash_apply_conflicts(
"url",
[
"git@github.com:iterative/scmrepo.git",
"github.com:iterative/scmrepo.git",
"user@github.com:iterative/scmrepo.git",
"ssh://login@server.com:12345/repository.git",
],
)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from scmrepo.urls import is_scp_style_url


@pytest.mark.parametrize(
    "url",
    [
        "git@github.com:iterative/scmrepo.git",
        "github.com:iterative/scmrepo.git",
        "user@github.com:iterative/scmrepo.git",
    ],
)
def test_scp_url(url: str):
    """SCP-style shorthand is recognised with the 'git' user, no user, or another user."""
    recognised = is_scp_style_url(url)
    assert recognised


@pytest.mark.parametrize(
    "url",
    [
        r"C:\foo\bar",
        "C:/foo/bar",
        "/home/user/iterative/scmrepo/git",
        "~/iterative/scmrepo/git",
        "ssh://login@server.com:12345/repository.git",
        "https://user:password@github.com/iterative/scmrepo.git",
        "https://github.com/iterative/scmrepo.git",
    ],
)
def test_scp_url_invalid(url: str):
    """Windows paths, POSIX paths and URLs with an explicit scheme are not SCP-style."""
    recognised = is_scp_style_url(url)
    assert not recognised

0 comments on commit 8987d1a

Please sign in to comment.