Skip to content

Commit

Permalink
lfs: authenticate with Git credentials only to Batch API requests
Browse files Browse the repository at this point in the history
  • Loading branch information
sisp authored and pmrowla committed Jan 29, 2024
1 parent 87a9e91 commit c81d7f4
Showing 1 changed file with 26 additions and 35 deletions.
61 changes: 26 additions & 35 deletions src/scmrepo/git/lfs/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from collections.abc import Awaitable, Iterable
from collections.abc import Iterable
from contextlib import AbstractContextManager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional

import aiohttp
from dvc_http import HTTPFileSystem
Expand Down Expand Up @@ -32,29 +31,6 @@ def _prepare_credentials(self, **config):
return {}


def _authed(f: Callable[..., Awaitable]):
"""Set credentials and retry the given coroutine if needed."""

# pylint: disable=protected-access
@wraps(f) # type: ignore[arg-type]
async def wrapper(self, *args, **kwargs):
try:
return await f(self, *args, **kwargs)
except aiohttp.ClientResponseError as exc:
if exc.status != 401:
raise
session = await self._set_session()
if session.auth:
raise
auth = self._get_auth()
if auth is None:
raise
self._session._auth = auth
return await f(self, *args, **kwargs)

return wrapper


class LFSClient(AbstractContextManager):
"""Naive read-only LFS HTTP client."""

Expand Down Expand Up @@ -112,7 +88,6 @@ def _get_auth(self) -> Optional[aiohttp.BasicAuth]:
async def _set_session(self) -> aiohttp.ClientSession:
return await self.fs.fs.set_session()

@_authed
async def _batch_request(
self,
objects: Iterable[Pointer],
Expand All @@ -134,14 +109,30 @@ async def _batch_request(
headers = dict(self.headers)
headers["Accept"] = self.JSON_CONTENT_TYPE
headers["Content-Type"] = self.JSON_CONTENT_TYPE
async with session.post(
url,
headers=headers,
json=body,
) as resp:
return await resp.json()

@_authed
try:
async with session.post(
url,
headers=headers,
json=body,
raise_for_status=True,
) as resp:
data = await resp.json()
except aiohttp.ClientResponseError as exc:
if exc.status != 401:
raise
auth = self._get_auth()
if auth is None:
raise
async with session.post(
url,
auth=auth,
headers=headers,
json=body,
raise_for_status=True,
) as resp:
data = await resp.json()
return data

async def _download(
self,
storage: "LFSStorage",
Expand Down

0 comments on commit c81d7f4

Please sign in to comment.