feat: better rate limiting logic (#317)
* feat: better rate-limiting logic

* fix: if -> while

* feat: better rate-limiting logic

* fix: suppress StopIteration warning

* chore: round log msg to 3 decimals

* chore: `black .`

* chore: suppress spammy logs

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
BobTheBuidler and github-actions[bot] authored Dec 4, 2024
1 parent 4ae7575 commit a3c60aa
Showing 2 changed files with 62 additions and 38 deletions.
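In short, this commit moves rate-limit handling from per-call state to session-level state: the old code passed a growing _retry_after argument through recursive self.post() calls and tracked limited sessions in a module-level _limited list, while the new code has post() sleep until a shared _continue_requests_at timestamp before sending, and handle_too_many_requests() push that timestamp forward by the server's Retry-After header (falling back to _RETRY_AFTER = 1.0). A minimal self-contained sketch of the pattern before the diffs below; CooldownGate and its method names are illustrative, not the dank_mids API, and the update rule is a simplified equivalent of the one in the diff:

    import asyncio
    from time import time

    _DEFAULT_RETRY_AFTER = 1.0  # fallback when the server omits a Retry-After header


    class CooldownGate:
        """Illustrative sketch of a session-wide request cooldown."""

        def __init__(self) -> None:
            self._continue_requests_at = 0.0

        async def wait_before_request(self) -> None:
            # If an earlier 429 set a cooldown that has not expired yet,
            # sleep it off before sending the next request.
            if (now := time()) < self._continue_requests_at:
                await asyncio.sleep(self._continue_requests_at - now)

        async def on_too_many_requests(self, retry_after_header: str = "") -> None:
            # Push the shared resume timestamp forward so all concurrent callers
            # back off together, then wait until it is reached.
            retry_after = float(retry_after_header or _DEFAULT_RETRY_AFTER)
            self._continue_requests_at = max(self._continue_requests_at, time()) + retry_after
            await asyncio.sleep(self._continue_requests_at - time())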
2 changes: 1 addition & 1 deletion dank_mids/brownie_patch/call.py
@@ -64,7 +64,7 @@
         *args: The arguments to be encoded.
     """
 
-    # We do this so ypricemagic's checksum cache monkey patch will work,
+    # We assign this variable so ypricemagic's checksum cache monkey patch will work,
     # This is only relevant to you if your project uses ypricemagic as well.
     to_checksum_address = Address.checksum
 
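The reworded comment makes the purpose of the local binding explicit: because the function re-reads Address.checksum on every call, a replacement that ypricemagic installs on the class later is picked up automatically. A tiny self-contained sketch of that late-binding behavior, with hypothetical stand-ins rather than the real Address or ypricemagic API:

    class Address:
        @staticmethod
        def checksum(addr: str) -> str:
            return addr.lower()  # stand-in for real EIP-55 checksumming

    def encode_call(addr: str) -> str:
        to_checksum_address = Address.checksum  # looked up at call time
        return to_checksum_address(addr)

    def cached_checksum(addr: str) -> str:
        return addr.upper()  # stand-in for ypricemagic's cached version

    Address.checksum = cached_checksum  # the monkey patch
    assert encode_call("0xAbC") == "0XABC"  # the patched implementation is used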
98 changes: 61 additions & 37 deletions dank_mids/helpers/_session.py
@@ -5,6 +5,7 @@
 from itertools import chain
 from random import random
 from threading import get_ident
+from time import time
 from typing import Any, Callable, List, Optional, overload
 
 import msgspec
@@ -123,8 +124,18 @@ async def get_session() -> "DankClientSession":
     return await _get_session_for_thread(get_ident())
 
 
+_RETRY_AFTER = 1.0
+
+
 class DankClientSession(ClientSession):
-    async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, _retry_after: float = 1, **kwargs) -> bytes:  # type: ignore [override]
+    _limited = False
+    _last_rate_limited_at = 0
+    _continue_requests_at = 0
+
+    async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, **kwargs) -> bytes:  # type: ignore [override]
+        if (now := time()) < self._continue_requests_at:
+            await asyncio.sleep(self._continue_requests_at - now)
+
         # Process input arguments.
         if isinstance(kwargs.get("data"), PartialRequest):
             logger.debug("making request for %s", kwargs["data"])
@@ -142,40 +153,56 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
                 return response_data
             except ClientResponseError as ce:
                 if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS:  # type: ignore [attr-defined]
-                    try_after = float(ce.headers.get("Retry-After", _retry_after * 1.5))  # type: ignore [union-attr]
-                    if self not in _limited:
-                        _limited.append(self)
-                        logger.info("You're being rate limited by your node provider")
-                        logger.info(
-                            "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
-                        )
-                    logger.info(f"rate limited: retrying after {try_after}s")
-                    await asyncio.sleep(try_after)
-                    if try_after > 30:
-                        logger.warning("severe rate limiting from your provider")
-                    return await self.post(
-                        endpoint, *args, loads=loads, _retry_after=try_after, **kwargs
-                    )
-
-                try:
-                    if ce.status not in RETRY_FOR_CODES or tried >= 5:
-                        logger.debug(
-                            "response failed with status %s", HTTPStatusExtended(ce.status)
-                        )
-                        raise ce
-                except ValueError as ve:
-                    raise (
-                        ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
-                    ) from ve
-
-                sleep = random()
-                await asyncio.sleep(sleep)
-                logger.debug(
-                    "response failed with status %s, retrying in %ss",
-                    HTTPStatusExtended(ce.status),
-                    round(sleep, 2),
-                )
-                tried += 1
+                    await self.handle_too_many_requests(ce)
+                else:
+                    try:
+                        if ce.status not in RETRY_FOR_CODES or tried >= 5:
+                            logger.debug(
+                                "response failed with status %s", HTTPStatusExtended(ce.status)
+                            )
+                            raise ce
+                    except ValueError as ve:
+                        raise (
+                            ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
+                        ) from ve
+
+                    sleep = random()
+                    await asyncio.sleep(sleep)
+                    logger.debug(
+                        "response failed with status %s, retrying in %ss",
+                        HTTPStatusExtended(ce.status),
+                        round(sleep, 2),
+                    )
+                    tried += 1
+
+    async def handle_too_many_requests(self, error: ClientResponseError) -> None:
+        now = time()
+        self._last_rate_limited_at = now
+        retry_after = float(error.headers.get("Retry-After", _RETRY_AFTER))
+        resume_at = max(
+            self._continue_requests_at + retry_after,
+            self._last_rate_limited_at + retry_after,
+        )
+        retry_after = resume_at - now
+        self._continue_requests_at = resume_at
+
+        self._log_rate_limited(retry_after)
+        await asyncio.sleep(retry_after)
+
+        if retry_after > 30:
+            logger.warning("severe rate limiting from your provider")
+
+    def _log_rate_limited(self, try_after: float) -> None:
+        if not self._limited:
+            self._limited = True
+            logger.info("You're being rate limited by your node provider")
+            logger.info(
+                "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
+            )
+        if try_after < 5:
+            logger.debug("rate limited: retrying after %.3fs", try_after)
+        else:
+            logger.info("rate limited: retrying after %.3fs", try_after)
 
 
 @alru_cache(maxsize=None)
@@ -191,6 +218,3 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
         raise_for_status=True,
         read_bufsize=2**20,  # 1mb
     )
-
-
-_limited: List[DankClientSession] = []
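As a worked trace of the resume_at arithmetic in the new handle_too_many_requests (illustrative epoch timestamps; assume both 429s carry Retry-After: 2):

    RETRY_AFTER = 2.0  # parsed from the Retry-After header (assumed value)

    # First 429 arrives at t=100.0 with no active cooldown (_continue_requests_at == 0).
    continue_at = max(0.0 + RETRY_AFTER, 100.0 + RETRY_AFTER)          # 102.0 -> sleep 2.0s
    # Second 429 arrives at t=101.0 from a request that was already in flight.
    continue_at = max(continue_at + RETRY_AFTER, 101.0 + RETRY_AFTER)  # 104.0 -> sleep 3.0s

Overlapping rate-limit responses therefore extend one shared cooldown instead of each request sleeping independently and retrying at the same moment, which is also why the recursive self.post() retry and the module-level _limited list could be removed.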
