Skip to content

Commit

Permalink
feat: REQUESTS_PER_SECOND env var and auto-adjusting limit (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Dec 19, 2024
1 parent a4d1319 commit b245acc
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 39 deletions.
2 changes: 2 additions & 0 deletions dank_mids/ENVIRONMENT_VARIABLES.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
MAX_MULTICALL_SIZE = _envs.create_env("MAX_MULTICALL_SIZE", int, default=10_000)
# Max number of rpc calls to include in one batch call
MAX_JSONRPC_BATCH_SIZE = _envs.create_env("MAX_JSONRPC_BATCH_SIZE", int, default=500)
# Maximum number of requests per second
REQUESTS_PER_SECOND = _envs.create_env("REQUESTS_PER_SECOND", int, default=50)

# Enable Demo Mode?
demo_mode = _envs._deprecated_format.create_env("DEMO_MODE", bool, default=False, verbose=False)
Expand Down
14 changes: 9 additions & 5 deletions dank_mids/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from dank_mids._exceptions import DankMidsInternalError
from dank_mids._requests import JSONRPCBatch, Multicall, RPCRequest, eth_call
from dank_mids._uid import UIDGenerator, _AlertingRLock
from dank_mids.helpers import _codec, _helpers, _session
from dank_mids.helpers._codec import decode_raw
from dank_mids.helpers._helpers import w3_version_major, _make_hashable, _sync_w3_from_async
from dank_mids.helpers._session import post, rate_limit_inactive
from dank_mids.semaphores import _MethodQueues, _MethodSemaphores, BlockSemaphore
from dank_mids.types import BlockId, PartialRequest, RawResponse, Request

Expand Down Expand Up @@ -57,7 +59,7 @@ def __init__(self, w3: Web3) -> None:
self.w3: Web3 = w3
"""The Web3 instance used to make rpc requests."""

self.sync_w3 = _helpers._sync_w3_from_async(w3)
self.sync_w3 = _sync_w3_from_async(w3)
"""A sync Web3 instance connected to the same rpc, used to make calls during init."""

self.chain_id = self.sync_w3.eth.chain_id
Expand Down Expand Up @@ -192,6 +194,8 @@ async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse:
The response from the RPC call.
"""

await rate_limit_inactive(self.endpoint)

# eth_call go thru a specialized Semaphore and other methods pass thru unblocked
if method == "eth_call":
async with self.eth_call_semaphores[params[1]]:
Expand All @@ -216,7 +220,7 @@ async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse:
except TypeError as e:
if "unhashable type" not in str(e):
raise
return await queue(self, method, _helpers._make_hashable(params))
return await queue(self, method, _make_hashable(params))

@eth_retry.auto_retry
async def make_request(
Expand All @@ -243,7 +247,7 @@ async def make_request(
method=method, params=params, id=request_id or self.call_uid.next
)
try:
return await _session.post(self.endpoint, data=request, loads=_codec.decode_raw)
return await post(self.endpoint, data=request, loads=decode_raw)
except Exception as e:
if ENVS.DEBUG:
_debugging.failures.record(
Expand Down Expand Up @@ -412,7 +416,7 @@ def _select_mcall_target_for_block(self, block) -> "_MulticallContract":

@eth_retry.auto_retry
def _get_client_version(sync_w3: Web3) -> str:
return sync_w3.client_version if _helpers.w3_version_major >= 6 else sync_w3.clientVersion # type: ignore [attr-defined]
return sync_w3.client_version if w3_version_major >= 6 else sync_w3.clientVersion # type: ignore [attr-defined]


class _MulticallContract(Struct):
Expand Down
134 changes: 100 additions & 34 deletions dank_mids/helpers/_session.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import http
import logging
from asyncio import sleep
from collections import defaultdict
from enum import IntEnum
from itertools import chain
from logging import DEBUG, getLogger
from random import random
from threading import get_ident
from time import time
from typing import Any, Callable, overload
from typing import Any, Callable, Tuple, overload

from a_sync import Event
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.typedefs import DEFAULT_JSON_DECODER, JSONDecoder
from aiolimiter import AsyncLimiter
from async_lru import alru_cache
from dank_mids import ENVIRONMENT_VARIABLES
from dank_mids import ENVIRONMENT_VARIABLES as ENVS
from dank_mids.helpers._codec import encode
from dank_mids.types import JSONRPCBatchResponse, PartialRequest, RawResponse

logger = logging.getLogger("dank_mids.session")
logger = getLogger("dank_mids.session")


# NOTE: You cannot subclass an IntEnum object so we have to do some hacky shit here.
Expand Down Expand Up @@ -102,7 +104,39 @@ def __new__(cls, value, phrase, description=""):
HTTPStatusExtended.CLOUDFLARE_TIMEOUT, # type: ignore [attr-defined]
}

limiter = AsyncLimiter(5, 0.1) # 50 requests/second

# Per-endpoint rate limiters, created lazily the first time an endpoint is looked up.
# Each limiter is built with ENVS.REQUESTS_PER_SECOND as its max rate over a
# 1-second period (default is 50 requests/second).
limiters = defaultdict(lambda: AsyncLimiter(ENVS.REQUESTS_PER_SECOND, 1))

# Maps an endpoint to the Event that `rate_limit_inactive` sets once it has finished
# waiting on that endpoint's limiter. An entry exists only while such a wait is running.
_rate_limit_waiters = {}


async def rate_limit_inactive(endpoint: str) -> None:
    """Wait until the rate limiter for ``endpoint`` has no pending waiters.

    Returns immediately when the limiter's private ``_waiters`` mapping is empty.
    Otherwise the first caller becomes the "drainer": it repeatedly awaits the
    most recently added waiter until the mapping is empty. Any caller that
    arrives while a drain is in progress simply waits on the drainer's Event
    instead of awaiting the limiter's waiters itself.

    Args:
        endpoint: The rpc endpoint whose limiter (``limiters[endpoint]``) is checked.

    Raises:
        Whatever awaiting one of the limiter's waiters raises (including
        ``asyncio.CancelledError``). The drain Event is still set and removed
        in that case so followers are released rather than left waiting.
    """
    # NOTE(review): relies on the limiter's private `_waiters` being a dict of
    # awaitables (it is used with `popitem` and `await` below) -- confirm
    # against the installed aiolimiter version.
    waiters = limiters[endpoint]._waiters
    if not waiters:
        return

    # Someone else is already draining this endpoint; wait for them to finish.
    if waiter := _rate_limit_waiters.get(endpoint):
        await waiter.wait()
        return

    drained = _rate_limit_waiters[endpoint] = Event()
    try:
        # `waiters` is known to be non-empty here (no await happened since the
        # check above), so the loop body always runs at least once.
        while waiters:
            # Peek at the last item without removing it: pop it, then put it
            # straight back so the limiter still sees it.
            last_key, last_waiter = waiters.popitem()
            waiters[last_key] = last_waiter
            await last_waiter
    finally:
        # Always clear the registry entry and wake followers, even if the await
        # above raised; otherwise every later caller for this endpoint would
        # block forever on a stale Event.
        _rate_limit_waiters.pop(endpoint, None)
        drained.set()


@overload
Expand Down Expand Up @@ -136,48 +170,73 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
await sleep(self._continue_requests_at - now)

# Process input arguments.
if isinstance(kwargs.get("data"), PartialRequest):
_logger_debug("making request for %s", kwargs["data"])
kwargs["data"] = encode(kwargs["data"])
_logger_debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs)
data = kwargs.get("data")
if debug_logs_enabled := _logger_is_enabled_for(DEBUG):
if isinstance(data, PartialRequest):
kwargs["data"] = encode(data)
_logger_log(DEBUG, "making request for %s", (data,))
_logger_log(
DEBUG,
"making request to %s with (args, kwargs): (%s %s)",
(endpoint, args, kwargs),
)
elif isinstance(data, PartialRequest):
kwargs["data"] = encode(data)

# Try the request until success or 5 failures.
tried = 0
while True:
try:
async with limiter:
async with limiters[endpoint]:
async with ClientSession.post(self, endpoint, *args, **kwargs) as response:
response_data = await response.json(loads=loads, content_type=None)
_logger_debug("received response %s", response_data)
return response_data
except ClientResponseError as ce:
if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS: # type: ignore [attr-defined]
await self.handle_too_many_requests(ce)
await self.handle_too_many_requests(endpoint, ce)
else:
try:
if ce.status not in RETRY_FOR_CODES or tried >= 5:
_logger_debug(
"response failed with status %s", HTTPStatusExtended(ce.status)
"response failed with status %s",
HTTPStatusExtended(ce.status),
)
raise ce
except ValueError as ve:
raise (
ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
) from ve

sleep_time = random()
await sleep(sleep_time)
_logger_debug(
"response failed with status %s, retrying in %ss",
HTTPStatusExtended(ce.status),
round(sleep_time, 2),
)
tried += 1

async def handle_too_many_requests(self, error: ClientResponseError) -> None:
if str(ve).endswith("is not a valid HTTPStatusExtended"):
raise ce from ve
raise
else:
tried += 1
if debug_logs_enabled:
sleep_for = random()
_logger_log(
DEBUG,
"response failed with status %s, retrying in %ss",
(HTTPStatusExtended(ce.status), round(sleep_for, 2)),
)
await sleep(sleep_for)
else:
await sleep(random())

async def handle_too_many_requests(self, endpoint: str, error: ClientResponseError) -> None:
limiter = limiters[endpoint]
if (now := time()) > getattr(limiter, "_last_updated_at", 0) + 60:
current_rate = limiter._rate_per_sec
new_rate = current_rate * 0.99
limiter._rate_per_sec = new_rate
limiter._last_updated_at = now
_logger_info(
"reduced requests per second for %s from %s to %s",
endpoint,
current_rate,
new_rate,
)

now = time()
self._last_rate_limited_at = now
retry_after = float(error.headers.get("Retry-After", _RETRY_AFTER))
retry_after = float(error.headers.get("Retry-After", 1 / limiter._rate_per_sec))
resume_at = max(
self._continue_requests_at + retry_after,
self._last_rate_limited_at + retry_after,
Expand All @@ -186,22 +245,21 @@ async def handle_too_many_requests(self, error: ClientResponseError) -> None:
self._continue_requests_at = resume_at

self._log_rate_limited(retry_after)
await sleep(retry_after)

if retry_after > 30:
logger.warning("severe rate limiting from your provider")
_logger_warning("severe rate limiting from your provider")
await sleep(retry_after)

def _log_rate_limited(self, try_after: float) -> None:
    """Log that the node provider rate limited a request.

    The first time this is called on a session (``self._limited`` is falsy) it
    flips ``self._limited`` to True and emits a two-line informational notice.
    On every call it then logs the retry delay: at debug level for delays
    under 5 seconds, at info level otherwise.

    Args:
        try_after: Number of seconds until the request will be retried.
    """
    if not self._limited:
        # Only announce rate limiting once per session.
        self._limited = True
        _logger_info("You're being rate limited by your node provider")
        _logger_info(
            "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
        )
    if try_after < 5:
        # Short delays are routine; keep them out of the info log.
        _logger_debug("rate limited: retrying after %.3fs", try_after)
    else:
        _logger_info("rate limited: retrying after %.3fs", try_after)


@alru_cache(maxsize=None)
Expand All @@ -213,7 +271,7 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
# I'm testing the value to use for limit, eventually will make an env var for this with an appropriate default
connector = TCPConnector(limit=0, enable_cleanup_closed=True)
client_timeout = ClientTimeout( # type: ignore [arg-type, attr-defined]
int(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT)
int(ENVS.AIOHTTP_TIMEOUT)
)
return DankClientSession(
connector=connector,
Expand All @@ -224,7 +282,15 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
)


# Typing stubs for the module-level logger shortcuts. They exist only to give
# the names below a signature; each one is immediately rebound to the matching
# bound method of `logger`, so these bodies never run.
def _logger_is_enabled_for(level: int) -> bool: ...
def _logger_warning(msg: str, *args: Any) -> None: ...
def _logger_info(msg: str, *args: Any) -> None: ...
def _logger_debug(msg: str, *args: Any) -> None: ...
def _logger_log(level: int, msg: str, args: Tuple[Any, ...]) -> None: ...


# Bind the logger's methods to module-level names once, so call sites skip the
# attribute lookup on `logger` each time.
_logger_is_enabled_for = logger.isEnabledFor
_logger_warning = logger.warning
_logger_info = logger.info
_logger_debug = logger.debug
# NOTE: `Logger._log` is a private logging API; it takes the format args as a
# single tuple and performs no level check, so callers must check the level first.
_logger_log = logger._log

0 comments on commit b245acc

Please sign in to comment.