feat: remove max connections limit (#322)
* feat: remove max connections limit

* chore: `black .`

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
BobTheBuidler and github-actions[bot] authored Dec 19, 2024
1 parent 8d91514 commit a4d1319
Showing 9 changed files with 60 additions and 45 deletions.
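Context for the headline change: the shared aiohttp session in dank_mids/helpers/_session.py previously capped its connection pool with TCPConnector(limit=32) and now passes limit=0, which aiohttp treats as "no limit on simultaneous connections". A minimal standalone sketch of the distinction (the URL is a placeholder, not part of this commit):

import asyncio

from aiohttp import ClientSession, TCPConnector


async def main() -> None:
    # limit=0 disables aiohttp's connection-pool cap entirely
    # (the library default is 100; this commit removes an explicit cap of 32).
    connector = TCPConnector(limit=0, enable_cleanup_closed=True)
    async with ClientSession(connector=connector) as session:
        async with session.get("https://example.com") as resp:  # placeholder URL
            print(resp.status)


asyncio.run(main())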
4 changes: 2 additions & 2 deletions dank_mids/_batch.py
@@ -1,5 +1,5 @@
-import asyncio
import logging
+from asyncio import gather
from typing import TYPE_CHECKING, Any, Awaitable, Generator, List, Union

from dank_mids._exceptions import DankMidsInternalError
@@ -78,7 +78,7 @@ async def _await(self) -> None:
it will be re-raised after all coroutines have been processed.
"""
batches = tuple(self.coroutines)
-for batch, result in zip(batches, await asyncio.gather(*batches, return_exceptions=True)):
+for batch, result in zip(batches, await gather(*batches, return_exceptions=True)):
if isinstance(result, Exception):
if not isinstance(result, DankMidsInternalError):
logger.error(
14 changes: 7 additions & 7 deletions dank_mids/brownie_patch/_method.py
@@ -1,6 +1,6 @@
-import asyncio
-import functools
+from asyncio import gather
from decimal import Decimal
+from functools import cached_property
from typing import Any, Awaitable, Callable, Dict, Generic, Iterable, List, Optional, TypeVar

from brownie.convert.datatypes import EthAddress
@@ -68,7 +68,7 @@ async def map(
self.coroutine(call_arg, block_identifier=block_identifier, decimals=decimals)
for call_arg in args
]
-return await asyncio.gather(*coros)
+return await gather(*coros)

@property
def abi(self) -> dict:
@@ -116,23 +116,23 @@ async def coroutine(
def _input_sig(self) -> str:
return self._abi.input_sig

-@functools.cached_property
+@cached_property
def _len_inputs(self) -> int:
return len(self.abi["inputs"])

-@functools.cached_property
+@cached_property
def _skip_decoder_proc_pool(self) -> bool:
from dank_mids.brownie_patch.call import _skip_proc_pool

return self._address in _skip_proc_pool

-@functools.cached_property
+@cached_property
def _call(cls) -> DankWeb3:
from dank_mids import web3

return web3.eth.call

-@functools.cached_property
+@cached_property
def _prep_request_data(self) -> Callable[..., Awaitable[BytesLike]]:
from dank_mids.brownie_patch import call

10 changes: 8 additions & 2 deletions dank_mids/controller.py
@@ -196,14 +196,14 @@ async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse:
if method == "eth_call":
async with self.eth_call_semaphores[params[1]]:
# create a strong ref to the call that will be held until the caller completes or is cancelled
logger.debug(f"making {self.request_type.__name__} {method} with params {params}")
_logger_debug(f"making {self.request_type.__name__} {method} with params {params}")
if params[0]["to"] in self.no_multicall:
return await RPCRequest(self, method, params)
return await eth_call(self, params)

# some methods go thru a SmartProcessingQueue, we check those next
queue = self.method_queues[method]
logger.debug(f"making {self.request_type.__name__} {method} with params {params}")
_logger_debug(f"making {self.request_type.__name__} {method} with params {params}")

# no queue, we can make the request normally
if queue is None:
@@ -464,3 +464,9 @@ def __hash__(self) -> int:
A hash value based on the contract's address.
"""
return hash(self.address)


+def _logger_debug(msg: str, *args: Any) -> None: ...
+
+
+_logger_debug = logger.debug
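The lines added above are a micro-optimization idiom: the def line is a typed stub whose only job is to give the name a signature for type checkers, and the assignment immediately rebinds the name to the bound method logger.debug, so each call site skips one attribute lookup. A standalone sketch of the same pattern (hypothetical module, not part of this diff):

import logging
from typing import Any

logger = logging.getLogger("example")


# Typed stub: exists only so type checkers see a signature for the name.
def _logger_debug(msg: str, *args: Any) -> None: ...


# Rebind to the bound method; call sites now avoid the `logger.` lookup.
_logger_debug = logger.debug

_logger_debug("debug call via cached bound method: %s", "works")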
6 changes: 3 additions & 3 deletions dank_mids/eth.py
@@ -1,4 +1,4 @@
-import asyncio
+from asyncio import sleep
from typing import (
Awaitable,
Callable,
@@ -194,7 +194,7 @@ async def trace_filter(
# TODO: figure out a better way to handle intermittent errors
while (traces_bytes := await self._trace_filter(filter_params)) is None and attempts < 5:
attempts += 1
-await asyncio.sleep(1)
+await sleep(1)
try:
return json.decode(traces_bytes, type=decode_to, dec_hook=decode_hook)
except ValidationError:
@@ -267,7 +267,7 @@ async def get_logs(
# TODO: figure out a better way to handle intermittent errors
while (logs_bytes := await self._get_logs_raw(*args, **kwargs)) is None and attempts < 5: # type: ignore [attr-defined]
attempts += 1
-await asyncio.sleep(1)
+await sleep(1)
return json.decode(logs_bytes, type=decode_to, dec_hook=decode_hook)

meth = MethodNoFormat.default(RPC.eth_getTransactionReceipt) # type: ignore [arg-type, var-annotated]
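Both hunks in this file share the same retry shape: an assignment expression retries a call that intermittently returns None, up to five attempts with a one-second pause. A generic standalone sketch (fetch is a hypothetical stand-in, not a dank_mids API):

import asyncio
from random import random
from typing import Optional


async def fetch() -> Optional[bytes]:
    # Hypothetical stand-in for an RPC call that intermittently returns None.
    return b"payload" if random() < 0.5 else None


async def fetch_with_retries() -> Optional[bytes]:
    attempts = 0
    # Keep retrying while the call returns None, up to 5 attempts.
    while (result := await fetch()) is None and attempts < 5:
        attempts += 1
        await asyncio.sleep(1)
    return result


print(asyncio.run(fetch_with_retries()))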
4 changes: 2 additions & 2 deletions dank_mids/helpers/_helpers.py
@@ -1,4 +1,4 @@
-import asyncio
+from asyncio import as_completed
from functools import wraps
from importlib.metadata import version
from typing import (
@@ -127,7 +127,7 @@ async def await_all(futs: Iterable[Awaitable]) -> None:
Args:
futs: An iterable of awaitables to be executed.
"""
-for fut in asyncio.as_completed([*futs]):
+for fut in as_completed([*futs]):
await fut
del fut

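For reference, asyncio.as_completed yields awaitables in completion order, and the del fut in the loop body drops each reference as soon as it has been awaited, so finished awaitables can be garbage-collected while the rest are still running. A standalone sketch of the same shape (toy coroutines, not repo code):

import asyncio
from typing import Awaitable, Iterable


async def await_all(futs: Iterable[Awaitable]) -> None:
    # Await in completion order; drop each reference immediately so
    # finished awaitables can be reclaimed while others still run.
    for fut in asyncio.as_completed([*futs]):
        await fut
        del fut


asyncio.run(await_all(asyncio.sleep(i / 10) for i in range(3)))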
48 changes: 29 additions & 19 deletions dank_mids/helpers/_session.py
@@ -1,21 +1,20 @@
-import asyncio
import http
import logging
+from asyncio import sleep
from enum import IntEnum
from itertools import chain
from random import random
from threading import get_ident
from time import time
-from typing import Any, Callable, List, Optional, overload
+from typing import Any, Callable, overload

import msgspec
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.typedefs import DEFAULT_JSON_DECODER, JSONDecoder
from aiolimiter import AsyncLimiter
from async_lru import alru_cache
from dank_mids import ENVIRONMENT_VARIABLES
-from dank_mids.helpers import _codec
+from dank_mids.helpers._codec import encode
from dank_mids.types import JSONRPCBatchResponse, PartialRequest, RawResponse

logger = logging.getLogger("dank_mids.session")
@@ -134,30 +133,30 @@ class DankClientSession(ClientSession):

async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, **kwargs) -> bytes: # type: ignore [override]
if (now := time()) < self._continue_requests_at:
-await asyncio.sleep(self._continue_requests_at - now)
+await sleep(self._continue_requests_at - now)

# Process input arguments.
if isinstance(kwargs.get("data"), PartialRequest):
logger.debug("making request for %s", kwargs["data"])
kwargs["data"] = _codec.encode(kwargs["data"])
logger.debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs)
_logger_debug("making request for %s", kwargs["data"])
kwargs["data"] = encode(kwargs["data"])
_logger_debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs)

# Try the request until success or 5 failures.
tried = 0
while True:
try:
async with limiter:
-async with super().post(endpoint, *args, **kwargs) as response:
+async with ClientSession.post(self, endpoint, *args, **kwargs) as response:
response_data = await response.json(loads=loads, content_type=None)
logger.debug("received response %s", response_data)
_logger_debug("received response %s", response_data)
return response_data
except ClientResponseError as ce:
if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS: # type: ignore [attr-defined]
await self.handle_too_many_requests(ce)
else:
try:
if ce.status not in RETRY_FOR_CODES or tried >= 5:
-logger.debug(
+_logger_debug(
"response failed with status %s", HTTPStatusExtended(ce.status)
)
raise ce
@@ -166,12 +165,12 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
) from ve

-sleep = random()
-await asyncio.sleep(sleep)
-logger.debug(
+sleep_time = random()
+await sleep(sleep_time)
+_logger_debug(
"response failed with status %s, retrying in %ss",
HTTPStatusExtended(ce.status),
-round(sleep, 2),
+round(sleep_time, 2),
)
tried += 1

@@ -187,7 +186,7 @@ async def handle_too_many_requests(self, error: ClientResponseError) -> None:
self._continue_requests_at = resume_at

self._log_rate_limited(retry_after)
-await asyncio.sleep(retry_after)
+await sleep(retry_after)

if retry_after > 30:
logger.warning("severe rate limiting from your provider")
@@ -200,7 +199,7 @@ def _log_rate_limited(self, try_after: float) -> None:
"Its all good, dank_mids has this handled, but you might get results slower than you'd like"
)
if try_after < 5:
logger.debug("rate limited: retrying after %.3fs", try_after)
_logger_debug("rate limited: retrying after %.3fs", try_after)
else:
logger.info("rate limited: retrying after %.3fs", try_after)

@@ -211,10 +210,21 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
This makes our ClientSession threadsafe just in case.
Most everything should be run in main thread though.
"""
+# I'm testing the value to use for limit, eventually will make an env var for this with an appropriate default
+connector = TCPConnector(limit=0, enable_cleanup_closed=True)
+client_timeout = ClientTimeout(  # type: ignore [arg-type, attr-defined]
+    int(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT)
+)
return DankClientSession(
-connector=TCPConnector(limit=32),
+connector=connector,
headers={"content-type": "application/json"},
-timeout=ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT),  # type: ignore [arg-type, attr-defined]
+timeout=client_timeout,
raise_for_status=True,
read_bufsize=2**20, # 1mb
)


+def _logger_debug(msg: str, *args: Any) -> None: ...
+
+
+_logger_debug = logger.debug
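A theme running through this file, and the commit as a whole: from asyncio import sleep binds the function as a module global, which CPython resolves faster than the two-step attribute lookup in asyncio.sleep, and ClientSession.post(self, ...) likewise avoids constructing a super() proxy on every request. A rough, unscientific way to observe the lookup difference (numbers vary by machine and Python version):

import timeit

# Global-name lookup vs. module-attribute lookup for the same function.
print(timeit.timeit("sleep", setup="from asyncio import sleep"))
print(timeit.timeit("asyncio.sleep", setup="import asyncio"))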
6 changes: 3 additions & 3 deletions dank_mids/stats.py
@@ -16,8 +16,8 @@

# TODO: Robust and Refactor

-import asyncio
import logging
+from asyncio import create_task, sleep
from collections import defaultdict, deque
from concurrent.futures import ProcessPoolExecutor
from time import time
@@ -208,7 +208,7 @@ def _ensure_daemon(self) -> None:
that occurred during its execution.
"""
if (ENVS.COLLECT_STATS or self.enabled) and self._daemon is None: # type: ignore [attr-defined,has-type]
-self._daemon = asyncio.create_task(self._stats_daemon())
+self._daemon = create_task(self._stats_daemon())
elif self._daemon.done():
raise self._daemon.exception() # type: ignore [misc]

@@ -220,7 +220,7 @@ async def _stats_daemon(self) -> None:
start = time()
time_since_notified = 0
while True:
-await asyncio.sleep(0)
+await sleep(0)
now = time()
duration = now - start
collector.event_loop_times.append(duration)
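Context for the hunk above: await sleep(0) suspends until the event loop's very next cycle, so the wall-clock delta between iterations approximates current event-loop latency, which is what the stats daemon records. A tiny standalone illustration (not repo code):

import asyncio
from time import time


async def loop_latency_probe(samples: int = 5) -> None:
    start = time()
    for _ in range(samples):
        # sleep(0) yields exactly one event-loop iteration.
        await asyncio.sleep(0)
        now = time()
        print(f"loop iteration took {now - start:.6f}s")
        start = now


asyncio.run(loop_latency_probe())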
4 changes: 2 additions & 2 deletions tests/test_brownie_patch.py
@@ -1,5 +1,5 @@
# sourcery skip: no-loop-in-tests
-import asyncio
+from asyncio import gather
from decimal import Decimal

import brownie
@@ -31,7 +31,7 @@ async def test_gather():
weth = get_contract("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
_patch_call(weth.totalSupply, dank_mids.web3)
assert hasattr(weth.totalSupply, "coroutine")
-for result in await asyncio.gather(
+for result in await gather(
*[weth.totalSupply.coroutine(block_identifier=13_000_000) for _ in range(10_000)]
):
assert result == 6620041514474872981393155
9 changes: 4 additions & 5 deletions tests/test_dank_mids.py
@@ -1,6 +1,5 @@
-import asyncio
import importlib
-import sys
+from asyncio import gather

import pytest
from brownie import chain
@@ -32,7 +31,7 @@

@pytest.mark.asyncio_cooperative
async def test_dank_middleware():
-await asyncio.gather(*BIG_WORK)
+await gather(*BIG_WORK)
controller = instances[chain.id][0]
cid = controller.call_uid.latest
mid = controller.multicall_uid.latest
@@ -76,7 +75,7 @@ async def test_json_batch():
This test verifies that the system can correctly handle and process
a batch of JSON-RPC requests across multiple blocks.
"""
-await asyncio.gather(*MULTIBLOCK_WORK)
+await gather(*MULTIBLOCK_WORK)


def test_next_cid():
@@ -122,7 +121,7 @@ async def test_other_methods():
dank_web3.eth.get_block("0xe25822"),
dank_web3.manager.coro_request(RPC.web3_clientVersion, []),
]
-results = await asyncio.gather(*work)
+results = await gather(*work)
assert results
assert results[-2].timestamp

