diff --git a/dank_mids/ENVIRONMENT_VARIABLES.py b/dank_mids/ENVIRONMENT_VARIABLES.py index 53dcfa54..7ae8c300 100644 --- a/dank_mids/ENVIRONMENT_VARIABLES.py +++ b/dank_mids/ENVIRONMENT_VARIABLES.py @@ -35,6 +35,8 @@ MAX_MULTICALL_SIZE = _envs.create_env("MAX_MULTICALL_SIZE", int, default=10_000) # Max number of rpc calls to include in one batch call MAX_JSONRPC_BATCH_SIZE = _envs.create_env("MAX_JSONRPC_BATCH_SIZE", int, default=500) +# Maximum amount of requests per second +REQUESTS_PER_SECOND = _envs.create_env("REQUESTS_PER_SECOND", int, default=50) # Enable Demo Mode? demo_mode = _envs._deprecated_format.create_env("DEMO_MODE", bool, default=False, verbose=False) diff --git a/dank_mids/controller.py b/dank_mids/controller.py index 9af54432..a7705cf3 100644 --- a/dank_mids/controller.py +++ b/dank_mids/controller.py @@ -20,7 +20,9 @@ from dank_mids._exceptions import DankMidsInternalError from dank_mids._requests import JSONRPCBatch, Multicall, RPCRequest, eth_call from dank_mids._uid import UIDGenerator, _AlertingRLock -from dank_mids.helpers import _codec, _helpers, _session +from dank_mids.helpers._codec import decode_raw +from dank_mids.helpers._helpers import w3_version_major, _make_hashable, _sync_w3_from_async +from dank_mids.helpers._session import post, rate_limit_inactive from dank_mids.semaphores import _MethodQueues, _MethodSemaphores, BlockSemaphore from dank_mids.types import BlockId, PartialRequest, RawResponse, Request @@ -57,7 +59,7 @@ def __init__(self, w3: Web3) -> None: self.w3: Web3 = w3 """The Web3 instance used to make rpc requests.""" - self.sync_w3 = _helpers._sync_w3_from_async(w3) + self.sync_w3 = _sync_w3_from_async(w3) """A sync Web3 instance connected to the same rpc, used to make calls during init.""" self.chain_id = self.sync_w3.eth.chain_id @@ -192,6 +194,8 @@ async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse: The response from the RPC call. """ + await rate_limit_inactive(self.endpoint) + # eth_call go thru a specialized Semaphore and other methods pass thru unblocked if method == "eth_call": async with self.eth_call_semaphores[params[1]]: @@ -216,7 +220,7 @@ async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse: except TypeError as e: if "unhashable type" not in str(e): raise - return await queue(self, method, _helpers._make_hashable(params)) + return await queue(self, method, _make_hashable(params)) @eth_retry.auto_retry async def make_request( @@ -243,7 +247,7 @@ async def make_request( method=method, params=params, id=request_id or self.call_uid.next ) try: - return await _session.post(self.endpoint, data=request, loads=_codec.decode_raw) + return await post(self.endpoint, data=request, loads=decode_raw) except Exception as e: if ENVS.DEBUG: _debugging.failures.record( @@ -412,7 +416,7 @@ def _select_mcall_target_for_block(self, block) -> "_MulticallContract": @eth_retry.auto_retry def _get_client_version(sync_w3: Web3) -> str: - return sync_w3.client_version if _helpers.w3_version_major >= 6 else sync_w3.clientVersion # type: ignore [attr-defined] + return sync_w3.client_version if w3_version_major >= 6 else sync_w3.clientVersion # type: ignore [attr-defined] class _MulticallContract(Struct): diff --git a/dank_mids/helpers/_session.py b/dank_mids/helpers/_session.py index ffa394b3..9a4e5f1a 100644 --- a/dank_mids/helpers/_session.py +++ b/dank_mids/helpers/_session.py @@ -1,23 +1,25 @@ import http -import logging from asyncio import sleep +from collections import defaultdict from enum import IntEnum from itertools import chain +from logging import DEBUG, getLogger from random import random from threading import get_ident from time import time -from typing import Any, Callable, overload +from typing import Any, Callable, Tuple, overload +from a_sync import Event from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientResponseError from aiohttp.typedefs import DEFAULT_JSON_DECODER, JSONDecoder from aiolimiter import AsyncLimiter from async_lru import alru_cache -from dank_mids import ENVIRONMENT_VARIABLES +from dank_mids import ENVIRONMENT_VARIABLES as ENVS from dank_mids.helpers._codec import encode from dank_mids.types import JSONRPCBatchResponse, PartialRequest, RawResponse -logger = logging.getLogger("dank_mids.session") +logger = getLogger("dank_mids.session") # NOTE: You cannot subclass an IntEnum object so we have to do some hacky shit here. @@ -102,7 +104,39 @@ def __new__(cls, value, phrase, description=""): HTTPStatusExtended.CLOUDFLARE_TIMEOUT, # type: ignore [attr-defined] } -limiter = AsyncLimiter(5, 0.1) # 50 requests/second + +# default is 50 requests/second +limiters = defaultdict(lambda: AsyncLimiter(ENVS.REQUESTS_PER_SECOND, 1)) + +_rate_limit_waiters = {} + + +async def rate_limit_inactive(endpoint: str) -> None: + # wait until the last future has been cleared from the rate limiter + if not (waiters := limiters[endpoint]._waiters): + return + + if waiter := _rate_limit_waiters.get(endpoint): + await waiter.wait() + return + + _rate_limit_waiters[endpoint] = Event() + + # pop last item + last_key, last_waiter = waiters.popitem() + # replace it + waiters[last_key] = last_waiter + # await it + await last_waiter + while waiters: + # pop last item + last_key, last_waiter = waiters.popitem() + # replace it + waiters[last_key] = last_waiter + # await it + await last_waiter + + _rate_limit_waiters.pop(endpoint).set() @overload @@ -136,48 +170,73 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC await sleep(self._continue_requests_at - now) # Process input arguments. - if isinstance(kwargs.get("data"), PartialRequest): - _logger_debug("making request for %s", kwargs["data"]) - kwargs["data"] = encode(kwargs["data"]) - _logger_debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs) + data = kwargs.get("data") + if debug_logs_enabled := _logger_is_enabled_for(DEBUG): + if isinstance(data, PartialRequest): + kwargs["data"] = encode(data) + _logger_log(DEBUG, "making request for %s", (data,)) + _logger_log( + DEBUG, + "making request to %s with (args, kwargs): (%s %s)", + (endpoint, args, kwargs), + ) + elif isinstance(data, PartialRequest): + kwargs["data"] = encode(data) # Try the request until success or 5 failures. tried = 0 while True: try: - async with limiter: + async with limiters[endpoint]: async with ClientSession.post(self, endpoint, *args, **kwargs) as response: response_data = await response.json(loads=loads, content_type=None) _logger_debug("received response %s", response_data) return response_data except ClientResponseError as ce: if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS: # type: ignore [attr-defined] - await self.handle_too_many_requests(ce) + await self.handle_too_many_requests(endpoint, ce) else: try: if ce.status not in RETRY_FOR_CODES or tried >= 5: _logger_debug( - "response failed with status %s", HTTPStatusExtended(ce.status) + "response failed with status %s", + HTTPStatusExtended(ce.status), ) raise ce except ValueError as ve: - raise ( - ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve - ) from ve - - sleep_time = random() - await sleep(sleep_time) - _logger_debug( - "response failed with status %s, retrying in %ss", - HTTPStatusExtended(ce.status), - round(sleep_time, 2), - ) - tried += 1 - - async def handle_too_many_requests(self, error: ClientResponseError) -> None: + if str(ve).endswith("is not a valid HTTPStatusExtended"): + raise ce from ve + raise + else: + tried += 1 + if debug_logs_enabled: + sleep_for = random() + _logger_log( + DEBUG, + "response failed with status %s, retrying in %ss", + (HTTPStatusExtended(ce.status), round(sleep_for, 2)), + ) + await sleep(sleep_for) + else: + await sleep(random()) + + async def handle_too_many_requests(self, endpoint: str, error: ClientResponseError) -> None: + limiter = limiters[endpoint] + if (now := time()) > getattr(limiter, "_last_updated_at", 0) + 60: + current_rate = limiter._rate_per_sec + new_rate = current_rate * 0.99 + limiter._rate_per_sec = new_rate + limiter._last_updated_at = now + _logger_info( + "reduced requests per second for %s from %s to %s", + endpoint, + current_rate, + new_rate, + ) + now = time() self._last_rate_limited_at = now - retry_after = float(error.headers.get("Retry-After", _RETRY_AFTER)) + retry_after = float(error.headers.get("Retry-After", 1 / limiter._rate_per_sec)) resume_at = max( self._continue_requests_at + retry_after, self._last_rate_limited_at + retry_after, @@ -186,22 +245,21 @@ async def handle_too_many_requests(self, error: ClientResponseError) -> None: self._continue_requests_at = resume_at self._log_rate_limited(retry_after) - await sleep(retry_after) - if retry_after > 30: - logger.warning("severe rate limiting from your provider") + _logger_warning("severe rate limiting from your provider") + await sleep(retry_after) def _log_rate_limited(self, try_after: float) -> None: if not self._limited: self._limited = True - logger.info("You're being rate limited by your node provider") - logger.info( + _logger_info("You're being rate limited by your node provider") + _logger_info( "Its all good, dank_mids has this handled, but you might get results slower than you'd like" ) if try_after < 5: _logger_debug("rate limited: retrying after %.3fs", try_after) else: - logger.info("rate limited: retrying after %.3fs", try_after) + _logger_info("rate limited: retrying after %.3fs", try_after) @alru_cache(maxsize=None) @@ -213,7 +271,7 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession: # I'm testing the value to use for limit, eventually will make an env var for this with an appropriate default connector = TCPConnector(limit=0, enable_cleanup_closed=True) client_timeout = ClientTimeout( # type: ignore [arg-type, attr-defined] - int(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT) + int(ENVS.AIOHTTP_TIMEOUT) ) return DankClientSession( connector=connector, @@ -224,7 +282,15 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession: ) +def _logger_is_enabled_for(msg: str, *args: Any) -> None: ... +def _logger_warning(msg: str, *args: Any) -> None: ... +def _logger_info(msg: str, *args: Any) -> None: ... def _logger_debug(msg: str, *args: Any) -> None: ... +def _logger_log(level: int, msg: str, args: Tuple[Any, ...]) -> None: ... +_logger_is_enabled_for = logger.isEnabledFor +_logger_warning = logger.warning +_logger_info = logger.info _logger_debug = logger.debug +_logger_log = logger._log