diff --git a/dank_mids/helpers/session.py b/dank_mids/helpers/session.py index 28be55ba..30358262 100644 --- a/dank_mids/helpers/session.py +++ b/dank_mids/helpers/session.py @@ -6,11 +6,11 @@ from itertools import chain from threading import get_ident from random import random -from typing import Any, overload +from typing import Any, Callable, List, Optional, overload import msgspec from aiohttp import ClientSession as DefaultClientSession -from aiohttp import ClientTimeout +from aiohttp import ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientResponseError from aiohttp.typedefs import JSONDecoder from aiolimiter import AsyncLimiter @@ -75,10 +75,10 @@ def __new__(cls, value, phrase, description=''): limiter = AsyncLimiter(5, 0.1) # 50 requests/second @overload -async def post(endpoint: str, *args, loads = decode.raw, **kwargs) -> RawResponse:... +async def post(endpoint: str, *args, loads: Callable[[Any], RawResponse], **kwargs) -> RawResponse:... @overload -async def post(endpoint: str, *args, loads = decode.jsonrpc_batch, **kwargs) -> JSONRPCBatchResponse:... -async def post(endpoint: str, *args, loads: JSONDecoder = None, **kwargs) -> Any: +async def post(endpoint: str, *args, loads: Callable[[Any], JSONRPCBatchResponse], **kwargs) -> JSONRPCBatchResponse:... +async def post(endpoint: str, *args, loads: Optional[JSONDecoder] = None, **kwargs) -> Any: """Returns decoded json data from `endpoint`""" session = await get_session() return await session.post(endpoint, *args, loads=loads, **kwargs) @@ -87,7 +87,7 @@ async def get_session() -> "ClientSession": return await _get_session_for_thread(get_ident()) class ClientSession(DefaultClientSession): - async def post(self, endpoint: str, *args, loads: JSONDecoder = None, _retry_after: int = 1, **kwargs) -> bytes: + async def post(self, endpoint: str, *args, loads: Optional[JSONDecoder] = None, _retry_after: int = 1, **kwargs) -> bytes: # type: ignore [override] # Process input arguments. if isinstance(kwargs.get('data'), PartialRequest): logger.debug("making request for %s", kwargs['data']) @@ -134,7 +134,11 @@ async def _get_session_for_thread(thread_ident: int) -> ClientSession: This makes our ClientSession threadsafe just in case. Most everything should be run in main thread though. """ - timeout = ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT) - return ClientSession(headers={'content-type': 'application/json'}, timeout=timeout, raise_for_status=True) + return ClientSession( + connector = TCPConnector(force_close=True), + headers = {'content-type': 'application/json'}, + timeout = ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT), # type: ignore [arg-type] + raise_for_status = True, + ) -_limited = [] \ No newline at end of file +_limited: List[ClientSession] = [] \ No newline at end of file