From a4d131949448f4e61613172cb56aba7ab3eb5b74 Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:42:18 -0400 Subject: [PATCH] feat: remove max connections limit (#322) * feat: remove max connections limit * chore: `black .` --------- Co-authored-by: github-actions[bot] --- dank_mids/_batch.py | 4 +-- dank_mids/brownie_patch/_method.py | 14 ++++----- dank_mids/controller.py | 10 +++++-- dank_mids/eth.py | 6 ++-- dank_mids/helpers/_helpers.py | 4 +-- dank_mids/helpers/_session.py | 48 ++++++++++++++++++------------ dank_mids/stats.py | 6 ++-- tests/test_brownie_patch.py | 4 +-- tests/test_dank_mids.py | 9 +++--- 9 files changed, 60 insertions(+), 45 deletions(-) diff --git a/dank_mids/_batch.py b/dank_mids/_batch.py index 6db9a728..128dc24a 100644 --- a/dank_mids/_batch.py +++ b/dank_mids/_batch.py @@ -1,5 +1,5 @@ -import asyncio import logging +from asyncio import gather from typing import TYPE_CHECKING, Any, Awaitable, Generator, List, Union from dank_mids._exceptions import DankMidsInternalError @@ -78,7 +78,7 @@ async def _await(self) -> None: it will be re-raised after all coroutines have been processed. """ batches = tuple(self.coroutines) - for batch, result in zip(batches, await asyncio.gather(*batches, return_exceptions=True)): + for batch, result in zip(batches, await gather(*batches, return_exceptions=True)): if isinstance(result, Exception): if not isinstance(result, DankMidsInternalError): logger.error( diff --git a/dank_mids/brownie_patch/_method.py b/dank_mids/brownie_patch/_method.py index 1fa5ce5f..2f29415a 100644 --- a/dank_mids/brownie_patch/_method.py +++ b/dank_mids/brownie_patch/_method.py @@ -1,6 +1,6 @@ -import asyncio -import functools +from asyncio import gather from decimal import Decimal +from functools import cached_property from typing import Any, Awaitable, Callable, Dict, Generic, Iterable, List, Optional, TypeVar from brownie.convert.datatypes import EthAddress @@ -68,7 +68,7 @@ async def map( self.coroutine(call_arg, block_identifier=block_identifier, decimals=decimals) for call_arg in args ] - return await asyncio.gather(*coros) + return await gather(*coros) @property def abi(self) -> dict: @@ -116,23 +116,23 @@ async def coroutine( def _input_sig(self) -> str: return self._abi.input_sig - @functools.cached_property + @cached_property def _len_inputs(self) -> int: return len(self.abi["inputs"]) - @functools.cached_property + @cached_property def _skip_decoder_proc_pool(self) -> bool: from dank_mids.brownie_patch.call import _skip_proc_pool return self._address in _skip_proc_pool - @functools.cached_property + @cached_property def _call(cls) -> DankWeb3: from dank_mids import web3 return web3.eth.call - @functools.cached_property + @cached_property def _prep_request_data(self) -> Callable[..., Awaitable[BytesLike]]: from dank_mids.brownie_patch import call diff --git a/dank_mids/controller.py b/dank_mids/controller.py index a4df3ff5..9af54432 100644 --- a/dank_mids/controller.py +++ b/dank_mids/controller.py @@ -196,14 +196,14 @@ async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse: if method == "eth_call": async with self.eth_call_semaphores[params[1]]: # create a strong ref to the call that will be held until the caller completes or is cancelled - logger.debug(f"making {self.request_type.__name__} {method} with params {params}") + _logger_debug(f"making {self.request_type.__name__} {method} with params {params}") if params[0]["to"] in self.no_multicall: return await 
RPCRequest(self, method, params) return await eth_call(self, params) # some methods go thru a SmartProcessingQueue, we check those next queue = self.method_queues[method] - logger.debug(f"making {self.request_type.__name__} {method} with params {params}") + _logger_debug(f"making {self.request_type.__name__} {method} with params {params}") # no queue, we can make the request normally if queue is None: @@ -464,3 +464,9 @@ def __hash__(self) -> int: A hash value based on the contract's address. """ return hash(self.address) + + +def _logger_debug(msg: str, *args: Any) -> None: ... + + +_logger_debug = logger.debug diff --git a/dank_mids/eth.py b/dank_mids/eth.py index 6decec57..1e1c6ac3 100644 --- a/dank_mids/eth.py +++ b/dank_mids/eth.py @@ -1,4 +1,4 @@ -import asyncio +from asyncio import sleep from typing import ( Awaitable, Callable, @@ -194,7 +194,7 @@ async def trace_filter( # TODO: figure out a better way to handle intermittent errors while (traces_bytes := await self._trace_filter(filter_params)) is None and attempts < 5: attempts += 1 - await asyncio.sleep(1) + await sleep(1) try: return json.decode(traces_bytes, type=decode_to, dec_hook=decode_hook) except ValidationError: @@ -267,7 +267,7 @@ async def get_logs( # TODO: figure out a better way to handle intermittent errors while (logs_bytes := await self._get_logs_raw(*args, **kwargs)) is None and attempts < 5: # type: ignore [attr-defined] attempts += 1 - await asyncio.sleep(1) + await sleep(1) return json.decode(logs_bytes, type=decode_to, dec_hook=decode_hook) meth = MethodNoFormat.default(RPC.eth_getTransactionReceipt) # type: ignore [arg-type, var-annotated] diff --git a/dank_mids/helpers/_helpers.py b/dank_mids/helpers/_helpers.py index e1253ba4..f1f69e40 100644 --- a/dank_mids/helpers/_helpers.py +++ b/dank_mids/helpers/_helpers.py @@ -1,4 +1,4 @@ -import asyncio +from asyncio import as_completed from functools import wraps from importlib.metadata import version from typing import ( @@ -127,7 +127,7 @@ async def await_all(futs: Iterable[Awaitable]) -> None: Args: futs: An iterable of awaitables to be executed. 
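+
+    Note:
+        Futures are awaited in completion order, and each reference is
+        released as soon as it completes.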
""" - for fut in asyncio.as_completed([*futs]): + for fut in as_completed([*futs]): await fut del fut diff --git a/dank_mids/helpers/_session.py b/dank_mids/helpers/_session.py index 2e30065f..ffa394b3 100644 --- a/dank_mids/helpers/_session.py +++ b/dank_mids/helpers/_session.py @@ -1,21 +1,20 @@ -import asyncio import http import logging +from asyncio import sleep from enum import IntEnum from itertools import chain from random import random from threading import get_ident from time import time -from typing import Any, Callable, List, Optional, overload +from typing import Any, Callable, overload -import msgspec from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientResponseError from aiohttp.typedefs import DEFAULT_JSON_DECODER, JSONDecoder from aiolimiter import AsyncLimiter from async_lru import alru_cache from dank_mids import ENVIRONMENT_VARIABLES -from dank_mids.helpers import _codec +from dank_mids.helpers._codec import encode from dank_mids.types import JSONRPCBatchResponse, PartialRequest, RawResponse logger = logging.getLogger("dank_mids.session") @@ -134,22 +133,22 @@ class DankClientSession(ClientSession): async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, **kwargs) -> bytes: # type: ignore [override] if (now := time()) < self._continue_requests_at: - await asyncio.sleep(self._continue_requests_at - now) + await sleep(self._continue_requests_at - now) # Process input arguments. if isinstance(kwargs.get("data"), PartialRequest): - logger.debug("making request for %s", kwargs["data"]) - kwargs["data"] = _codec.encode(kwargs["data"]) - logger.debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs) + _logger_debug("making request for %s", kwargs["data"]) + kwargs["data"] = encode(kwargs["data"]) + _logger_debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs) # Try the request until success or 5 failures. 
tried = 0
         while True:
             try:
                 async with limiter:
-                    async with super().post(endpoint, *args, **kwargs) as response:
+                    async with ClientSession.post(self, endpoint, *args, **kwargs) as response:
                         response_data = await response.json(loads=loads, content_type=None)
-                        logger.debug("received response %s", response_data)
+                        _logger_debug("received response %s", response_data)
                         return response_data
             except ClientResponseError as ce:
                 if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS:  # type: ignore [attr-defined]
@@ -157,7 +156,7 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
                 else:
                     try:
                         if ce.status not in RETRY_FOR_CODES or tried >= 5:
-                            logger.debug(
+                            _logger_debug(
                                 "response failed with status %s", HTTPStatusExtended(ce.status)
                             )
                             raise ce
@@ -166,12 +165,12 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
                             ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
                         ) from ve

-                sleep = random()
-                await asyncio.sleep(sleep)
-                logger.debug(
+                sleep_time = random()
+                await sleep(sleep_time)
+                _logger_debug(
                     "response failed with status %s, retrying in %ss",
                     HTTPStatusExtended(ce.status),
-                    round(sleep, 2),
+                    round(sleep_time, 2),
                 )

                 tried += 1
@@ -187,7 +186,7 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
         self._continue_requests_at = resume_at
         self._log_rate_limited(retry_after)
-        await asyncio.sleep(retry_after)
+        await sleep(retry_after)

         if retry_after > 30:
             logger.warning("severe rate limiting from your provider")
@@ -200,7 +199,7 @@ def _log_rate_limited(self, try_after: float) -> None:
             "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
         )
         if try_after < 5:
-            logger.debug("rate limited: retrying after %.3fs", try_after)
+            _logger_debug("rate limited: retrying after %.3fs", try_after)
         else:
             logger.info("rate limited: retrying after %.3fs", try_after)
@@ -211,10 +210,21 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
     This makes our ClientSession threadsafe just in case. Most everything
     should be run in main thread though.
     """
+    # I'm still testing the value to use for limit; eventually this will become an env var with an appropriate default.
+    connector = TCPConnector(limit=0, enable_cleanup_closed=True)
+    client_timeout = ClientTimeout(  # type: ignore [arg-type, attr-defined]
+        int(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT)
+    )
     return DankClientSession(
-        connector=TCPConnector(limit=32),
+        connector=connector,
         headers={"content-type": "application/json"},
-        timeout=ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT),  # type: ignore [arg-type, attr-defined]
+        timeout=client_timeout,
         raise_for_status=True,
         read_bufsize=2**20,  # 1mb
     )
+
+
+def _logger_debug(msg: str, *args: Any) -> None: ...
+
+
+_logger_debug = logger.debug
diff --git a/dank_mids/stats.py b/dank_mids/stats.py
index 66def17c..699cd17b 100644
--- a/dank_mids/stats.py
+++ b/dank_mids/stats.py
@@ -16,8 +16,8 @@

 # TODO: Robust and Refactor

-import asyncio
 import logging
+from asyncio import create_task, sleep
 from collections import defaultdict, deque
 from concurrent.futures import ProcessPoolExecutor
 from time import time
@@ -208,7 +208,7 @@ def _ensure_daemon(self) -> None:
             that occurred during its execution.
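+
+            If the daemon has already died, the exception that killed it is
+            re-raised here.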
""" if (ENVS.COLLECT_STATS or self.enabled) and self._daemon is None: # type: ignore [attr-defined,has-type] - self._daemon = asyncio.create_task(self._stats_daemon()) + self._daemon = create_task(self._stats_daemon()) elif self._daemon.done(): raise self._daemon.exception() # type: ignore [misc] @@ -220,7 +220,7 @@ async def _stats_daemon(self) -> None: start = time() time_since_notified = 0 while True: - await asyncio.sleep(0) + await sleep(0) now = time() duration = now - start collector.event_loop_times.append(duration) diff --git a/tests/test_brownie_patch.py b/tests/test_brownie_patch.py index cccf7141..07c1abdb 100644 --- a/tests/test_brownie_patch.py +++ b/tests/test_brownie_patch.py @@ -1,5 +1,5 @@ # sourcery skip: no-loop-in-tests -import asyncio +from asyncio import gather from decimal import Decimal import brownie @@ -31,7 +31,7 @@ async def test_gather(): weth = get_contract("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2") _patch_call(weth.totalSupply, dank_mids.web3) assert hasattr(weth.totalSupply, "coroutine") - for result in await asyncio.gather( + for result in await gather( *[weth.totalSupply.coroutine(block_identifier=13_000_000) for _ in range(10_000)] ): assert result == 6620041514474872981393155 diff --git a/tests/test_dank_mids.py b/tests/test_dank_mids.py index 618dafdc..34a57161 100644 --- a/tests/test_dank_mids.py +++ b/tests/test_dank_mids.py @@ -1,6 +1,5 @@ -import asyncio -import importlib import sys +from asyncio import gather import pytest from brownie import chain @@ -32,7 +31,7 @@ @pytest.mark.asyncio_cooperative async def test_dank_middleware(): - await asyncio.gather(*BIG_WORK) + await gather(*BIG_WORK) controller = instances[chain.id][0] cid = controller.call_uid.latest mid = controller.multicall_uid.latest @@ -76,7 +75,7 @@ async def test_json_batch(): This test verifies that the system can correctly handle and process a batch of JSON-RPC requests across multiple blocks. """ - await asyncio.gather(*MULTIBLOCK_WORK) + await gather(*MULTIBLOCK_WORK) def test_next_cid(): @@ -122,7 +121,7 @@ async def test_other_methods(): dank_web3.eth.get_block("0xe25822"), dank_web3.manager.coro_request(RPC.web3_clientVersion, []), ] - results = await asyncio.gather(*work) + results = await gather(*work) assert results assert results[-2].timestamp