Skip to content

Commit

Permalink
mypy: shiny for aiohttp
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Sep 12, 2024
1 parent 7839776 commit b0b3f13
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions ampel/ztf/t3/skyportal/SkyPortalClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
from ampel.util.mappings import flatten_dict

if TYPE_CHECKING:
from typing import Unpack

from aiohttp.client import _RequestOptions

from ampel.config.AmpelConfig import AmpelConfig
from ampel.content.DataPoint import DataPoint
from ampel.view.T2DocView import T2DocView
Expand Down Expand Up @@ -185,9 +189,7 @@ def validate(cls, value: dict) -> Any:
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

self._request_kwargs = {
"headers": {"Authorization": f"token {self.token.get()}"}
}
self._auth_header = {"Authorization": f"token {self.token.get()}"}
self._ids: dict[str, dict[str, int]] = {}
self._session: None | aiohttp.ClientSession = None
self._semaphore: None | asyncio.Semaphore = None
Expand All @@ -210,7 +212,7 @@ async def request(
endpoint: str,
raise_exc: bool,
_decode_json: None,
**kwargs: dict[str, Any],
**kwargs: "Unpack[_RequestOptions]",
) -> aiohttp.ClientResponse:
...

Expand All @@ -221,7 +223,7 @@ async def request(
endpoint: str,
raise_exc: bool,
_decode_json: bool,
**kwargs: dict[str, Any],
**kwargs: "Unpack[_RequestOptions]",
) -> dict[str, Any]:
...

Expand All @@ -241,7 +243,7 @@ async def request(
endpoint: str,
raise_exc: bool = True,
_decode_json: None | bool = True,
**kwargs: dict[str, Any],
**kwargs: "Unpack[_RequestOptions]",
) -> aiohttp.ClientResponse | dict[str, Any]:
if self._session is None or self._semaphore is None:
raise ValueError(
Expand All @@ -251,6 +253,7 @@ async def request(
url = self.base_url + endpoint
else:
url = self.base_url + "/api/" + endpoint
kwargs["headers"] = dict(kwargs.get("headers") or {}) | self._auth_header
labels = (verb, endpoint.split("/")[0])
async with self._semaphore:
with (
Expand All @@ -265,7 +268,7 @@ async def request(
stat_concurrent_requests.labels(*labels).track_inprogress(),
):
async with self._session.request(
verb, url, **{**self._request_kwargs, **kwargs}
verb, url, **kwargs
) as response:
if response.status == 429 or response.status >= 500:
response.raise_for_status()
Expand Down

0 comments on commit b0b3f13

Please sign in to comment.