diff --git a/dank_mids/brownie_patch/call.py b/dank_mids/brownie_patch/call.py index 245f2f4b..e79175a9 100644 --- a/dank_mids/brownie_patch/call.py +++ b/dank_mids/brownie_patch/call.py @@ -64,7 +64,7 @@ *args: The arguments to be encoded. """ -# We do this so ypricemagic's checksum cache monkey patch will work, +# We assign this variable so ypricemagic's checksum cache monkey patch will work, # This is only relevant to you if your project uses ypricemagic as well. to_checksum_address = Address.checksum diff --git a/dank_mids/helpers/_session.py b/dank_mids/helpers/_session.py index e33f345a..2e30065f 100644 --- a/dank_mids/helpers/_session.py +++ b/dank_mids/helpers/_session.py @@ -5,6 +5,7 @@ from itertools import chain from random import random from threading import get_ident +from time import time from typing import Any, Callable, List, Optional, overload import msgspec @@ -123,8 +124,18 @@ async def get_session() -> "DankClientSession": return await _get_session_for_thread(get_ident()) +_RETRY_AFTER = 1.0 + + class DankClientSession(ClientSession): - async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, _retry_after: float = 1, **kwargs) -> bytes: # type: ignore [override] + _limited = False + _last_rate_limited_at = 0 + _continue_requests_at = 0 + + async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, **kwargs) -> bytes: # type: ignore [override] + if (now := time()) < self._continue_requests_at: + await asyncio.sleep(self._continue_requests_at - now) + # Process input arguments. if isinstance(kwargs.get("data"), PartialRequest): logger.debug("making request for %s", kwargs["data"]) @@ -142,40 +153,56 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC return response_data except ClientResponseError as ce: if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS: # type: ignore [attr-defined] - try_after = float(ce.headers.get("Retry-After", _retry_after * 1.5)) # type: ignore [union-attr] - if self not in _limited: - _limited.append(self) - logger.info("You're being rate limited by your node provider") - logger.info( - "Its all good, dank_mids has this handled, but you might get results slower than you'd like" - ) - logger.info(f"rate limited: retrying after {try_after}s") - await asyncio.sleep(try_after) - if try_after > 30: - logger.warning("severe rate limiting from your provider") - return await self.post( - endpoint, *args, loads=loads, _retry_after=try_after, **kwargs + await self.handle_too_many_requests(ce) + else: + try: + if ce.status not in RETRY_FOR_CODES or tried >= 5: + logger.debug( + "response failed with status %s", HTTPStatusExtended(ce.status) + ) + raise ce + except ValueError as ve: + raise ( + ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve + ) from ve + + sleep = random() + await asyncio.sleep(sleep) + logger.debug( + "response failed with status %s, retrying in %ss", + HTTPStatusExtended(ce.status), + round(sleep, 2), ) - - try: - if ce.status not in RETRY_FOR_CODES or tried >= 5: - logger.debug( - "response failed with status %s", HTTPStatusExtended(ce.status) - ) - raise ce - except ValueError as ve: - raise ( - ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve - ) from ve - - sleep = random() - await asyncio.sleep(sleep) - logger.debug( - "response failed with status %s, retrying in %ss", - HTTPStatusExtended(ce.status), - round(sleep, 2), - ) - tried += 1 + tried += 1 + + async def handle_too_many_requests(self, error: ClientResponseError) -> None: + now = time() + self._last_rate_limited_at = now + retry_after = float(error.headers.get("Retry-After", _RETRY_AFTER)) + resume_at = max( + self._continue_requests_at + retry_after, + self._last_rate_limited_at + retry_after, + ) + retry_after = resume_at - now + self._continue_requests_at = resume_at + + self._log_rate_limited(retry_after) + await asyncio.sleep(retry_after) + + if retry_after > 30: + logger.warning("severe rate limiting from your provider") + + def _log_rate_limited(self, try_after: float) -> None: + if not self._limited: + self._limited = True + logger.info("You're being rate limited by your node provider") + logger.info( + "Its all good, dank_mids has this handled, but you might get results slower than you'd like" + ) + if try_after < 5: + logger.debug("rate limited: retrying after %.3fs", try_after) + else: + logger.info("rate limited: retrying after %.3fs", try_after) @alru_cache(maxsize=None) @@ -191,6 +218,3 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession: raise_for_status=True, read_bufsize=2**20, # 1mb ) - - -_limited: List[DankClientSession] = []