diff --git a/.gitignore b/.gitignore index a6ededca..3301358a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ dank_mids.egg-info .pytest_cache .eggs .mypy_cache -__pycache__ \ No newline at end of file +__pycache__ +just-leave-me-here-and-dont-touch-me-plz.log \ No newline at end of file diff --git a/dank_mids/ENVIRONMENT_VARIABLES.py b/dank_mids/ENVIRONMENT_VARIABLES.py new file mode 100644 index 00000000..51667b7d --- /dev/null +++ b/dank_mids/ENVIRONMENT_VARIABLES.py @@ -0,0 +1,85 @@ + +import logging + +import a_sync +import typed_envs +from a_sync import AsyncProcessPoolExecutor + +from dank_mids import _envs +from dank_mids._mode import OperationMode +from dank_mids.semaphores import BlockSemaphore + +logger = logging.getLogger("dank_mids.envs") + +if not typed_envs.logger.disabled: + logger.info("For your information, you can tweak your configuration for optimal performance using any of the envs below:") + +############### +# ENVIRONMENT # +############### + +# What mode should dank mids operate in? +# NOTE: infura mode is required for now +# TODO: fix the other modes, set default='default', and make this verbose again +OPERATION_MODE = _envs.create_env("OPERATION_MODE", OperationMode, default="infura", verbose=False) + +# Max number of rpc calls to include in one batch call +MAX_JSONRPC_BATCH_SIZE = _envs.create_env("MAX_JSONRPC_BATCH_SIZE", int, default=500) + +# Enable Demo Mode? +demo_mode = _envs._deprecated_format.create_env("DEMO_MODE", bool, default=False, verbose=False) +DEMO_MODE = _envs.create_env("DEMO_MODE", bool, default=demo_mode, verbose=False) + +# Are you calling a ganache fork? Can't use state override code +ganache_fork = _envs._deprecated_format.create_env("GANACHE_FORK", bool, default=False, verbose=False) +GANACHE_FORK = _envs.create_env("GANACHE_FORK", bool, default=ganache_fork) + +# We set the default to 20 minutes to account for potentially long event loop times if you're doing serious work. +AIOHTTP_TIMEOUT = _envs.create_env("AIOHTTP_TIMEOUT", int, default=20*60, string_converter=int) + +# Brownie call Semaphore +# Used because I experienced some OOM errs due to web3 formatters when I was batching an absurd number of brownie calls. +# We need a separate semaphore here because the method-specific semaphores are too late in the code to prevent this OOM issue. +brownie_semaphore = _envs._deprecated_format.create_env("BROWNIE_CALL_SEMAPHORE", int, default=100_000, string_converter=int, verbose=False) +BROWNIE_CALL_SEMAPHORE = _envs.create_env("BROWNIE_CALL_SEMAPHORE", BlockSemaphore, default=brownie_semaphore, string_converter=int, verbose=not OPERATION_MODE.infura) +BROWNIE_ENCODER_SEMAPHORE = _envs.create_env("BROWNIE_ENCODER_SEMAPHORE", BlockSemaphore, default=BROWNIE_CALL_SEMAPHORE._default_value * 2, string_converter=int, verbose=not OPERATION_MODE.infura) + +# Processes for decoding. This determines process pool size, not total subprocess count. +# There are 3 pools, each initialized with the same value. +# NOTE: Don't stress, these are good for you and will not hog your cpu. You can disable them by setting the var = 0. 
# TODO: lol, you can't yet
BROWNIE_ENCODER_PROCESSES = _envs.create_env("BROWNIE_ENCODER_PROCESSES", AsyncProcessPoolExecutor, default=0 if OPERATION_MODE.infura else 1, string_converter=int, verbose=not OPERATION_MODE.infura)
+BROWNIE_DECODER_PROCESSES = _envs.create_env("BROWNIE_DECODER_PROCESSES", AsyncProcessPoolExecutor, default=0 if OPERATION_MODE.infura else 1, string_converter=int, verbose=not OPERATION_MODE.infura)
+MULTICALL_DECODER_PROCESSES = _envs.create_env("MULTICALL_DECODER_PROCESSES", AsyncProcessPoolExecutor, default=0 if OPERATION_MODE.infura else 1, string_converter=int, verbose=not OPERATION_MODE.infura)
+
+# NOTE: EXPORT_STATS is not implemented
+# TODO: implement this
+EXPORT_STATS = _envs.create_env("EXPORT_STATS", bool, default=False, verbose=False)
+# NOTE: COLLECT_STATS is implemented
+COLLECT_STATS = _envs.create_env("COLLECT_STATS", bool, default=EXPORT_STATS, verbose=not EXPORT_STATS)
+
+# You probably don't need to use this unless you know you need to
+STUCK_CALL_TIMEOUT = _envs.create_env("STUCK_CALL_TIMEOUT", int, default=60)
+
+# Method-specific Semaphores
+method_semaphores = {
+ "eth_call": _envs.create_env("ETH_CALL_SEMAPHORE", BlockSemaphore, default=BROWNIE_CALL_SEMAPHORE._value, string_converter=int),
+ "eth_getBlock": _envs.create_env("ETH_GETBLOCK_SEMAPHORE", a_sync.Semaphore, default=50, string_converter=int),
+ "eth_getLogs": _envs.create_env("ETH_GETLOGS_SEMAPHORE", a_sync.Semaphore, default=64, string_converter=int),
+ "eth_getTransaction": _envs.create_env("ETH_GETTRANSACTION_SEMAPHORE", a_sync.Semaphore, default=100, string_converter=int),
+}
+
+if not typed_envs.logger.disabled:
+ logger.info("More details can be found in dank_mids/ENVIRONMENT_VARIABLES.py")
+ logger.info("NOTE: You can disable these logs by setting the `TYPEDENVS_SHUTUP` env var to any value.")
+
+
+# Validate some stuff
+
+# NOTE: The other modes are (probably) bugging out right now. More investigation needed. For now, use infura mode.
+if not OPERATION_MODE.infura:
+ raise ValueError("Dank mids must be run in infura mode for now")
+
+if OPERATION_MODE.infura:
+ for process_pool in {MULTICALL_DECODER_PROCESSES, BROWNIE_DECODER_PROCESSES, BROWNIE_ENCODER_PROCESSES}:
+ if process_pool._max_workers:
+ raise ValueError(f"You cannot set env var {process_pool.name} while running dank in infura mode.") diff --git a/dank_mids/__init__.py b/dank_mids/__init__.py index 961f4fed..df63ddf6 100644 --- a/dank_mids/__init__.py +++ b/dank_mids/__init__.py @@ -1,4 +1,15 @@ +
+
+from dank_mids._how_is_this_real import _the_most_absurd_fix_youve_ever_seen
 from dank_mids.controller import instances
 from dank_mids.helpers import setup_dank_w3, setup_dank_w3_from_sync
 from dank_mids.middleware import dank_middleware
+
+
+def _configure_concurrent_future_work_queue_size():
+ import concurrent.futures.process as _cfp
+ _cfp.EXTRA_QUEUED_CALLS = 50_000
+
+_configure_concurrent_future_work_queue_size()
+_the_most_absurd_fix_youve_ever_seen() diff --git a/dank_mids/_config.py b/dank_mids/_config.py deleted file mode 100644 index 9f7f1b09..00000000 --- a/dank_mids/_config.py +++ /dev/null @@ -1,35 +0,0 @@ -
-import os
-
-from aiohttp import ClientTimeout
-
-LOOP_INTERVAL = float(os.environ.get("DANKMIDSLOOPINTERVAL", 0.01))
-
-# Max number of rpc calls to include in one batch call
-MAX_JSONRPC_BATCH_SIZE = int(os.environ.get("MAX_JSONRPC_BATCH_SIZE", 500))
-
-
-# Enable Demo Mode?
-DEMO_MODE = bool(os.environ.get("DANKMIDS_DEMO_MODE", os.environ.get("DANK_MIDS_DEMO_MODE"))) - -# Are you calling a ganache fork? Can't use state override code -GANACHE_FORK = bool(os.environ.get("DANKMIDS_GANACHE_FORK", os.environ.get("DANK_MIDS_GANACHE_FORK"))) - -# With default AsyncBaseProvider settings, some dense calls will fail -# due to aiohttp.TimeoutError where they would otherwise succeed. -# We set the default to 2 minutes but if you're doing serious work -# you may want to increase it further. -AIOHTTP_TIMEOUT = ClientTimeout(int(os.environ.get("AIOHTTP_TIMEOUT", 120))) - - -# Method-specific Semaphores -semaphore_envs = { - "eth_getBlock": int(os.environ.get("ETH_GETBLOCK_SEMAPHORE", 50)), # [eth_getBlockByNumber, eth_getBlockByHash] - "eth_getLogs": int(os.environ.get("ETH_GETLOGS_SEMAPHORE", 16)), - "eth_getTransaction": int(os.environ.get("ETH_GETTRANSACTION_SEMAPHORE", 100)), -} - -# Brownie call Semaphore -# Used because I experienced some OOM errs due to web3 formatters when I was batching an absurd number of brownie calls. -# We need a separate semaphore here because the method-specific semaphores are too late in the code to prevent this OOM issue. -BROWNIE_CALL_SEMAPHORE_VAL = int(os.environ.get("DANK_MIDS_BROWNIE_CALL_SEMAPHORE", 100_000)) diff --git a/dank_mids/_demo_mode.py b/dank_mids/_demo_mode.py index b2a2f93f..aa039547 100644 --- a/dank_mids/_demo_mode.py +++ b/dank_mids/_demo_mode.py @@ -1,7 +1,7 @@ import logging from typing import Any -from dank_mids._config import DEMO_MODE +from dank_mids import ENVIRONMENT_VARIABLES class DummyLogger: @@ -9,4 +9,4 @@ class DummyLogger: def info(self, *args: Any, **kwargs: Any) -> None: ... -demo_logger = logging.getLogger("dank_mids.demo") if DEMO_MODE else DummyLogger() +demo_logger = logging.getLogger("dank_mids.demo") if ENVIRONMENT_VARIABLES.DEMO_MODE else DummyLogger() diff --git a/dank_mids/_envs.py b/dank_mids/_envs.py new file mode 100644 index 00000000..73ac0959 --- /dev/null +++ b/dank_mids/_envs.py @@ -0,0 +1,8 @@ + +from typed_envs import EnvVarFactory + +_factory = EnvVarFactory("DANKMIDS") +create_env = _factory.create_env + +# This only applies to the oldest of dank's envs +_deprecated_format = EnvVarFactory("DANK_MIDS") diff --git a/dank_mids/_exceptions.py b/dank_mids/_exceptions.py new file mode 100644 index 00000000..9ca24bb8 --- /dev/null +++ b/dank_mids/_exceptions.py @@ -0,0 +1,45 @@ + +from typing import TYPE_CHECKING, Union +import logging + +from aiohttp.client_exceptions import ClientResponseError + +if TYPE_CHECKING: + from dank_mids.types import PartialRequest, PartialResponse + + +logger = logging.getLogger("dank_mids.exceptions") + +class BadResponse(ValueError): + def __init__(self, response: "PartialResponse") -> None: + self.response = response + super().__init__(response.to_dict()) + +class EmptyBatch(ValueError): + pass + +class ResponseNotReady(ValueError): + pass + +class PayloadTooLarge(BadResponse): + pass + +class DankMidsClientResponseError(ClientResponseError): + """A wrapper around the standard aiohttp ClientResponseError that attaches the request that generated the error.""" + def __init__( + self, + exc: ClientResponseError, + request: "PartialRequest", + ) -> None: + self.request = request + super().__init__(exc.request_info, exc.history, code=exc.code, status=exc.status, message=exc.message, headers=exc.headers) + self.args = (*exc.request_info, exc.history, request) + self._exception = exc + +internal_err_types = Union[AttributeError, TypeError, UnboundLocalError, 
NotImplementedError, RuntimeError, SyntaxError]
+
+class DankMidsInternalError(Exception):
+ def __init__(self, e: internal_err_types) -> None:
+ logger.warning(f"unhandled exception inside dank mids internals: {e}", exc_info=True)
+ self._original_exception = e
+ super().__init__(e.__repr__()) diff --git a/dank_mids/_how_is_this_real.py b/dank_mids/_how_is_this_real.py new file mode 100644 index 00000000..fe2fbc3e --- /dev/null +++ b/dank_mids/_how_is_this_real.py @@ -0,0 +1,37 @@ +
+import logging
+from logging.handlers import RotatingFileHandler
+
+LOGS_PATH = 'just-leave-me-here-and-dont-touch-me-plz.log'
+ONE_KB = 1024
+
+silencer = RotatingFileHandler(LOGS_PATH, maxBytes=500*ONE_KB, backupCount=0)
+
+def _the_most_absurd_fix_youve_ever_seen():
+ """
+ Somehow this resolves a race condition enough for the library to work as intended.
+ How? I have no idea at all. Please don't ask me, it's magic.
+ I'll debug this later once I've pushed a stable release to prod.
+ """
+ enable_logger_but_divert_stream('a_sync.abstract')
+ enable_logger_but_divert_stream('a_sync.base')
+ enable_logger_but_divert_stream('dank_mids.controller')
+ enable_logger_but_divert_stream('dank_mids.session')
+ enable_logger_but_divert_stream('dank_mids.semaphores')
+
+def enable_logger_but_divert_stream(name, show_every_x=None):
+ logger = logging.getLogger(name)
+ if logging.root.isEnabledFor(logging.DEBUG) or logger.isEnabledFor(logging.DEBUG):
+ # We don't need to do anything, the user has indicated they want the debug logs enabled and shown
+ return
+ # Break the logger apart from the root logger and its handlers
+ logger.propagate = False
+ # Enable the logger regardless of root settings
+ logger.setLevel(logging.DEBUG)
+ # remove the root handler that was added at the basicConfig step (not sure if this is actually necessary)
+ logger.removeHandler(logging.StreamHandler())
+ # ensure the logger has no handlers
+ assert logger.handlers == [], logger.handlers
+ # add the silencer handler to direct the logs to a throwaway file
+ logger.addHandler(silencer)
+ 
\ No newline at end of file diff --git a/dank_mids/_mode.py b/dank_mids/_mode.py new file mode 100644 index 00000000..5157900e --- /dev/null +++ b/dank_mids/_mode.py @@ -0,0 +1,22 @@ +
+from functools import cached_property
+
+MODES = ["default", "application", "infura"]
+
+class OperationMode(str):
+ @cached_property
+ def application(self) -> bool:
+ # This mode keeps the event loop as unblocked as possible so an asyncio application can run as designed
+ return self.mode == "application"
+ @cached_property
+ def default(self) -> bool:
+ return self.mode == "default"
+ @cached_property
+ def infura(self) -> bool:
+ # This mode minimizes the total number of calls sent to the node
+ return self.mode == "infura"
+ @property
+ def mode(self) -> str:
+ if self not in MODES:
+ raise ValueError(f'dank mids operation mode {self} is invalid', f'valid modes: {MODES}', str(self))
+ return self diff --git a/dank_mids/batch.py b/dank_mids/batch.py new file mode 100644 index 00000000..b4be4225 --- /dev/null +++ b/dank_mids/batch.py @@ -0,0 +1,77 @@ +
+import asyncio
+import logging
+from typing import TYPE_CHECKING, Any, Generator, List
+
+from dank_mids._exceptions import DankMidsInternalError
+from dank_mids.requests import JSONRPCBatch, RPCRequest, _Batch
+from dank_mids.types import Multicalls
+
+if TYPE_CHECKING:
+ from dank_mids.controller import DankMiddlewareController
+
+MIN_SIZE = 1 # TODO: Play with this
+CHECK = MIN_SIZE - 1
+
+logger = logging.getLogger(__name__)
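# Editor's note: DankBatch below is an "awaitable batch" - awaiting the object
# starts its queued work and then gathers the results. A minimal, self-contained
# sketch of the same pattern (hypothetical names, not part of dank_mids' API):
#
#     import asyncio
#     from typing import Any, Awaitable, Generator, List
#
#     class AwaitableBatch:
#         def __init__(self, coros: List[Awaitable]) -> None:
#             self._coros = coros
#             self._started = False
#
#         def __await__(self) -> Generator[Any, None, Any]:
#             self._started = True  # the start() step
#             return asyncio.gather(*self._coros).__await__()
#
# so `results = await AwaitableBatch([fetch(1), fetch(2)])` both schedules and
# collects everything in a single await, just like `await batch` is used in
# DankMiddlewareController.execute_batch().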
+ +class DankBatch: + """ A batch of jsonrpc batches. This is pretty much deprecated and needs to be refactored away.""" + def __init__(self, controller: "DankMiddlewareController", multicalls: Multicalls, rpc_calls: List[RPCRequest]): + self.controller = controller + self.multicalls = multicalls + self.rpc_calls = rpc_calls + self._started = False + + def __await__(self) -> Generator[Any, None, Any]: + self.start() + return self._await().__await__() + + async def _await(self) -> None: + batches = tuple(self.coroutines) + for batch, result in zip(batches, await asyncio.gather(*batches, return_exceptions=True)): + if isinstance(result, Exception): + if not isinstance(result, DankMidsInternalError): + logger.error(f"That's not good, there was an exception in a {batch.__class__.__name__}. These are supposed to be handled.\n{result}\n", exc_info=True) + raise result + + def start(self) -> None: + for mcall in self.multicalls.values(): + mcall.start(self, cleanup=False) + for call in self.rpc_calls: + call.start(self) + self._started = True + + @property + def coroutines(self) -> Generator["_Batch", None, None]: + # Combine multicalls into one or more jsonrpc batches + + # Create empty batch + working_batch = JSONRPCBatch(self.controller) + + check_len = min(CHECK, self.controller.batcher.step) + # Go thru the multicalls and add calls to the batch + for mcall in self.multicalls.values(): + # NOTE: If a multicall has less than `CHECK` calls, we should just throw the calls into a jsonrpc batch individually. + try: # NOTE: This should be faster than using len(). + mcall[check_len] + working_batch.append(mcall, skip_check=True) + except IndexError: + working_batch.extend(mcall, skip_check=True) + if working_batch.is_full: + yield working_batch + working_batch = JSONRPCBatch(self.controller) + + rpc_calls_to_batch = self.rpc_calls[:] + while rpc_calls_to_batch: + if working_batch.is_full: + yield working_batch + working_batch = JSONRPCBatch(self.controller) + working_batch.append(rpc_calls_to_batch.pop(), skip_check=True) + if working_batch: + if working_batch.is_single_multicall: + yield working_batch[0] # type: ignore [misc] + elif len(working_batch) == 1: + yield working_batch[0].make_request() + else: + yield working_batch diff --git a/dank_mids/brownie_patch/call.py b/dank_mids/brownie_patch/call.py index 5460d643..9f3a8b0c 100644 --- a/dank_mids/brownie_patch/call.py +++ b/dank_mids/brownie_patch/call.py @@ -1,62 +1,37 @@ -import functools +import logging +from concurrent.futures.process import BrokenProcessPool +from functools import lru_cache +from pickle import PicklingError from types import MethodType from typing import Any, Dict, Optional, Tuple, Union import eth_abi +from a_sync import AsyncProcessPoolExecutor from brownie.convert.normalize import format_input, format_output from brownie.convert.utils import get_type_strings from brownie.exceptions import VirtualMachineError -from brownie.network.contract import ContractCall +from brownie.network.contract import Contract, ContractCall from brownie.project.compiler.solidity import SOLIDITY_ERROR_CODES from hexbytes import HexBytes -from multicall.utils import run_in_subprocess from web3 import Web3 -from dank_mids._config import BROWNIE_CALL_SEMAPHORE_VAL -from dank_mids.semaphore import ThreadsafeSemaphore +from dank_mids import ENVIRONMENT_VARIABLES as ENVS -brownie_call_semaphore = ThreadsafeSemaphore(BROWNIE_CALL_SEMAPHORE_VAL) +logger = logging.getLogger(__name__) +encode = lambda self, *args: 
ENVS.BROWNIE_ENCODER_PROCESSES.run(__encode_input, self.abi, self.signature, *args) +decode = lambda self, data: ENVS.BROWNIE_DECODER_PROCESSES.run(__decode_output, data, self.abi) -def __encode_input(abi: Dict[str, Any], signature: str, *args: Tuple[Any,...]) -> str: - data = format_input(abi, args) - types_list = get_type_strings(abi["inputs"]) - return signature + eth_abi.encode_abi(types_list, data).hex() - -def __decode_output(hexstr: str, abi: Dict[str, Any]) -> Any: - selector = HexBytes(hexstr)[:4].hex() - if selector == "0x08c379a0": - revert_str = eth_abi.decode_abi(["string"], HexBytes(hexstr)[4:])[0] - raise ValueError(f"Call reverted: {revert_str}") - elif selector == "0x4e487b71": - error_code = int(HexBytes(hexstr)[4:].hex(), 16) - if error_code in SOLIDITY_ERROR_CODES: - revert_str = SOLIDITY_ERROR_CODES[error_code] - else: - revert_str = f"Panic (error code: {error_code})" - raise ValueError(f"Call reverted: {revert_str}") - if abi["outputs"] and not hexstr: - raise ValueError("No data was returned - the call likely reverted") - - types_list = get_type_strings(abi["outputs"]) - result = eth_abi.decode_abi(types_list, HexBytes(hexstr)) - result = format_output(abi, result) - if len(result) == 1: - result = result[0] - return result +def _patch_call(call: ContractCall, w3: Web3) -> None: + call.coroutine = MethodType(_get_coroutine_fn(w3, len(call.abi['inputs'])), call) -async def _encode_input(self: ContractCall, *args: Tuple[Any,...]) -> str: - return await run_in_subprocess( - __encode_input, - self.abi, - self.signature, - *(arg if not hasattr(arg, 'address') else arg.address for arg in args) # type: ignore - ) - -async def _decode_output(self: ContractCall, data: str) -> Any: - return await run_in_subprocess(__decode_output, data, self.abi) +@lru_cache +def _get_coroutine_fn(w3: Web3, len_inputs: int): + if ENVS.OPERATION_MODE.application: + get_request_data = encode + else: + get_request_data = encode if len_inputs else __request_data_no_args -def _patch_call(call: ContractCall, w3: Web3) -> None: async def coroutine( self: ContractCall, *args: Tuple[Any,...], @@ -65,18 +40,89 @@ async def coroutine( ) -> Any: if override: raise ValueError("Cannot use state override with `coroutine`.") - - async with brownie_call_semaphore: - try: - return await self._decode_output( - await w3.eth.call({"to": self._address, "data": await self._encode_input(*args)}, block_identifier) # type: ignore - ) - except ValueError as e: - try: - raise VirtualMachineError(e) from None - except: - raise e - - call.coroutine = MethodType(coroutine, call) - call._encode_input = MethodType(_encode_input, call) - call._decode_output = MethodType(_decode_output, call) + async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: + data = await encode_input(self, len_inputs, get_request_data, *args) + async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: + output = await w3.eth.call({"to": self._address, "data": data}, block_identifier) + return await decode_output(self, output) + + return coroutine + +async def encode_input(call: ContractCall, len_inputs, get_request_data, *args) -> bytes: + if any(isinstance(arg, Contract) for arg in args) or any(hasattr(arg, "__contains__") for arg in args): # We will just assume containers contain a Contract object until we have a better way to handle this + # We can't unpickle these because of the added `coroutine` method. 
+ data = __encode_input(call.abi, call.signature, *args)
+ else:
+ try: # We're better off sending these to the subprocess so they don't clog up the event loop.
+ data = await get_request_data(call, *args)
+ # TODO: move this somewhere else
+ except BrokenProcessPool:
+ logger.critical("Oh fuck, you broke the %s while encoding %s with abi %s", ENVS.BROWNIE_ENCODER_PROCESSES, args, call.abi)
+ # Let's fix that right up
+ ENVS.BROWNIE_ENCODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_ENCODER_PROCESSES._max_workers)
+ data = __encode_input(call.abi, call.signature, *args) if len_inputs else call.signature
+ except PicklingError: # But if that fails, don't worry. I got you.
+ data = __encode_input(call.abi, call.signature, *args) if len_inputs else call.signature
+ # We have to do it like this so we don't break the process pool.
+ if isinstance(data, Exception):
+ raise data
+ return data
+
+async def decode_output(call: ContractCall, data: bytes) -> Any:
+ __validate_output(call.abi, data)
+ try:
+ decoded = await decode(call, data)
+ # TODO: move this somewhere else
+ except BrokenProcessPool:
+ # Let's fix that right up
+ logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_DECODER_PROCESSES, data, call.abi)
+ ENVS.BROWNIE_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_DECODER_PROCESSES._max_workers)
+ decoded = __decode_output(data, call.abi)
+ # We have to do it like this so we don't break the process pool.
+ if isinstance(decoded, Exception):
+ raise decoded
+ return decoded
+
+async def __request_data_no_args(call: ContractCall) -> str:
+ return call.signature
+
+def __encode_input(abi: Dict[str, Any], signature: str, *args: Tuple[Any,...]) -> str:
+ try:
+ data = format_input(abi, args)
+ types_list = get_type_strings(abi["inputs"])
+ return signature + eth_abi.encode_abi(types_list, data).hex()
+ except Exception as e:
+ return e
+
+def __decode_output(hexstr: str, abi: Dict[str, Any]) -> Any:
+ try:
+ types_list = get_type_strings(abi["outputs"])
+ result = eth_abi.decode_abi(types_list, HexBytes(hexstr))
+ result = format_output(abi, result)
+ if len(result) == 1:
+ result = result[0]
+ return result
+ except Exception as e:
+ return e
+
+def __validate_output(abi: Dict[str, Any], hexstr: str):
+ try:
+ selector = HexBytes(hexstr)[:4].hex()
+ if selector == "0x08c379a0":
+ revert_str = eth_abi.decode_abi(["string"], HexBytes(hexstr)[4:])[0]
+ raise ValueError(f"Call reverted: {revert_str}")
+ elif selector == "0x4e487b71":
+ error_code = int(HexBytes(hexstr)[4:].hex(), 16)
+ if error_code in SOLIDITY_ERROR_CODES:
+ revert_str = SOLIDITY_ERROR_CODES[error_code]
+ else:
+ revert_str = f"Panic (error code: {error_code})"
+ raise ValueError(f"Call reverted: {revert_str}")
+ if abi["outputs"] and not hexstr:
+ raise ValueError("No data was returned - the call likely reverted")
+ except ValueError as e:
+ try:
+ raise VirtualMachineError(e) from None
+ except:
+ raise e
+ diff --git a/dank_mids/brownie_patch/contract.py b/dank_mids/brownie_patch/contract.py index c5b259a2..4c89663a 100644 --- a/dank_mids/brownie_patch/contract.py +++ b/dank_mids/brownie_patch/contract.py @@ -1,17 +1,19 @@ +from types import MethodType
 from typing import Optional, Union
 from brownie import Contract, network, web3
 from brownie.network.contract import ContractCall, ContractTx, OverloadedMethod
-from dank_mids.brownie_patch.call import _patch_call
-from dank_mids.brownie_patch.overloaded import _patch_overloaded_method
 from web3 import Web3
+from
dank_mids.brownie_patch.call import _get_coroutine_fn +from dank_mids.brownie_patch.overloaded import _patch_overloaded_method + ContractMethod = Union[ContractCall, ContractTx, OverloadedMethod] def _patch_if_method(method: ContractMethod, w3: Web3) -> None: - if isinstance(method, ContractCall) or isinstance(method, ContractTx): - _patch_call(method, w3) + if isinstance(method, (ContractCall, ContractTx)): + method.coroutine = MethodType(_get_coroutine_fn(w3, len(method.abi['inputs'])), method) elif isinstance(method, OverloadedMethod): _patch_overloaded_method(method, w3) diff --git a/dank_mids/brownie_patch/overloaded.py b/dank_mids/brownie_patch/overloaded.py index 4781aaaf..6fef877f 100644 --- a/dank_mids/brownie_patch/overloaded.py +++ b/dank_mids/brownie_patch/overloaded.py @@ -1,14 +1,15 @@ import functools from types import MethodType -from typing import Any, Coroutine, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from brownie import Contract from brownie.network.contract import ContractCall, ContractTx, OverloadedMethod -from dank_mids.brownie_patch.call import _patch_call +from dank_mids.brownie_patch.call import _get_coroutine_fn from web3 import Web3 def _patch_overloaded_method(call: OverloadedMethod, w3: Web3) -> None: + # sourcery skip: avoid-builtin-shadow @functools.wraps(call) async def coroutine( self: Contract, @@ -20,20 +21,20 @@ async def coroutine( fn = self._get_fn_from_args(args) except ValueError as e: if f"Contract has more than one function '.{call._name}" in str(e) and f"You must explicitly declare which function you are calling, e.g. .{call._name}" in str(e): - e = str(e) - breakpoint = e.find("(*args)") - raise ValueError(f"{e[:breakpoint]}.coroutine{e[breakpoint:]}") + exc_str = str(e) + breakpoint = exc_str.find("(*args)") + raise ValueError(f"{exc_str[:breakpoint]}.coroutine{exc_str[breakpoint:]}") raise e - + kwargs = {"block_identifier": block_identifier, "override": override} kwargs = {k: v for k, v in kwargs.items() if v is not None} return await fn.coroutine(*args, **kwargs) - for args, method in call.__dict__['methods'].items(): - if isinstance(method, ContractCall) or isinstance(method, ContractTx): - _patch_call(method, w3) + for method in call.__dict__['methods'].values(): + if isinstance(method, (ContractCall, ContractTx)): + method.coroutine = MethodType(_get_coroutine_fn(w3, len(method.abi['inputs'])), method) # TODO implement this properly #elif isinstance(call, ContractTx): #_patch_tx(call, w3) - + call.coroutine = MethodType(coroutine, call) diff --git a/dank_mids/constants.py b/dank_mids/constants.py index b4a5e180..fefb1ee0 100644 --- a/dank_mids/constants.py +++ b/dank_mids/constants.py @@ -2,15 +2,15 @@ import multicall -GAS_LIMIT = multicall.constants.GAS_LIMIT +TOO_MUCH_DATA_ERRS = ["Payload Too Large", "content length too large", "request entity too large"] +RETRY_ERRS = ["connection reset by peer", "server disconnected", "execution aborted (timeout = 5s)"] +GAS_LIMIT = multicall.constants.GAS_LIMIT OVERRIDE_CODE = multicall.constants.MULTICALL2_BYTECODE - - # When you get these call responses back from the multicall, we know there was some problem with execution. # If you make the exact same calls without multicall, you will get an Exception not a response. -# TODO: Replicate brownie's logic for detecting reverts. 
+# TODO: Delete these BAD_HEXES = [ # Chainlink feeds no access "0x08c379a0000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000094e6f206163636573730000000000000000000000000000000000000000000000", @@ -33,3 +33,8 @@ # Function does not exist "0x08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001846756e6374696f6e20646f6573206e6f742065786973742e0000000000000000", ] + +# Not sure why yet but sometimes a multicall will succeed but one of the results will be a failure for one call that doesn't interrupt the rest of the mcall. +# NOTE: we leave off the '0x' so we can compare raw bytes +# NOTE: The 2nd one here needs to be converted to the first format but I need to encounter one in the wild before I can do that +REVERT_SELECTORS = [b'\x08\xc3y\xa0', b"4e487b71"] diff --git a/dank_mids/controller.py b/dank_mids/controller.py index aff24865..ca933012 100644 --- a/dank_mids/controller.py +++ b/dank_mids/controller.py @@ -1,10 +1,10 @@ -import asyncio +import logging import threading from collections import defaultdict -from time import time -from typing import Any, DefaultDict, List +from typing import Any, DefaultDict, List, Literal, Optional +import eth_retry from eth_utils import to_checksum_address from multicall.constants import MULTICALL2_ADDRESSES, MULTICALL_ADDRESSES from multicall.multicall import NotSoBrightBatcher @@ -13,15 +13,17 @@ from web3.providers.async_base import AsyncBaseProvider from web3.types import RPCEndpoint, RPCResponse -from dank_mids._config import LOOP_INTERVAL +from dank_mids import ENVIRONMENT_VARIABLES as ENVS from dank_mids._demo_mode import demo_logger -from dank_mids.loggers import main_logger, sort_lazy_logger -from dank_mids.requests import RPCRequest, eth_call -from dank_mids.types import BlockId, ChainId -from dank_mids.uid import UIDGenerator -from dank_mids.worker import DankWorker +from dank_mids._exceptions import DankMidsInternalError +from dank_mids.batch import DankBatch +from dank_mids.helpers import decode, session +from dank_mids.requests import JSONRPCBatch, Multicall, RPCRequest, eth_call +from dank_mids.semaphores import MethodSemaphores +from dank_mids.types import BlockId, ChainId, PartialRequest, RawResponse +from dank_mids.uid import UIDGenerator, _AlertingRLock -BYPASS_METHODS = "eth_getLogs", "trace_", "debug_" +logger = logging.getLogger(__name__) instances: DefaultDict[ChainId, List["DankMiddlewareController"]] = defaultdict(list) @@ -32,95 +34,122 @@ def _sync_w3_from_async(w3: Web3) -> Web3: # We can't pickle middlewares to send to process executor. # The call has already passed thru all middlewares on the user's Web3 instance. sync_w3.middleware_onion.clear() - sync_w3.provider.middlewares = tuple() + sync_w3.provider.middlewares = () return sync_w3 class DankMiddlewareController: def __init__(self, w3: Web3) -> None: - main_logger.info('Dank Middleware initializing... Strap on your rocket boots...') + logger.info('Dank Middleware initializing... Strap on your rocket boots...') self.w3: Web3 = w3 self.sync_w3 = _sync_w3_from_async(w3) + self.chain_id = self.sync_w3.eth.chain_id + # NOTE: We need this mutable for node types that require the full jsonrpc spec + self.request_type = PartialRequest + self._time_of_request_type_change = 0 + self.state_override_not_supported: bool = ENVS.GANACHE_FORK or self.chain_id == 100 # Gnosis Chain does not support state override. 
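# Editor's note: "state override" here refers to eth_call's optional third
# parameter, which lets a single call execute against modified state - e.g.
# pretending the multicall bytecode is already deployed at an address. A hedged
# sketch of the call shape with web3.py (illustrative, not dank_mids' exact code):
#
#     await w3.eth.call(
#         {"to": multicall2, "data": calldata},
#         block_identifier,
#         {multicall2: {"code": OVERRIDE_CODE}},  # the state override mapping
#     )
#
# Ganache forks and Gnosis Chain (chain_id 100) reject such overrides, which is
# why the flag above disables them and dank falls back to plain calls.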
+
+ self.endpoint = self.w3.provider.endpoint_uri
+ if "tenderly" in self.endpoint:
+ # NOTE: Tenderly does funky things sometimes
+ logger.warning(
+ "We see you're using a Tenderly RPC.\n"
+ + "There is a known conflict between dank and Tenderly which causes issues not present with other providers.\n"
+ + "Your mileage may vary. Debugging efforts welcome."
+ )
+
+ self._instance: int = sum(len(_instances) for _instances in instances.values())
+ instances[self.chain_id].append(self) # type: ignore
+
 multicall = MULTICALL_ADDRESSES.get(self.chain_id)
 multicall2 = MULTICALL2_ADDRESSES.get(self.chain_id)
 if multicall2 is None:
 raise NotImplementedError("Dank Mids currently does not support this network.\nTo add support, you just need to submit a PR adding the appropriate multicall contract addresses to this file:\nhttps://github.com/banteg/multicall.py/blob/master/multicall/constants.py")
 self.multicall2 = to_checksum_address(multicall2)
 self.no_multicall = {self.multicall2} if multicall is None else {self.multicall2, to_checksum_address(multicall)}
- self.pending_eth_calls: List[eth_call] = []
- self.pending_rpc_calls: List[RPCRequest] = []
- self.num_pending_eth_calls: int = 0
- self.worker = DankWorker(self)
- self.is_running: bool = False
+
+ self.method_semaphores = MethodSemaphores(self)
+ self.batcher = NotSoBrightBatcher()
+
 self.call_uid = UIDGenerator()
- self._checkpoint: float = time()
- self._instance: int = sum(len(_instances) for _instances in instances.values())
- instances[self.chain_id].append(self) # type: ignore
+ self.multicall_uid: UIDGenerator = UIDGenerator()
+ self.request_uid: UIDGenerator = UIDGenerator()
+ self.jsonrpc_batch_uid: UIDGenerator = UIDGenerator()
+ self.pools_closed_lock = _AlertingRLock(name='pools closed')
+
+ self.pending_eth_calls: DefaultDict[BlockId, Multicall] = defaultdict(lambda: Multicall(self))
+ self.pending_rpc_calls = JSONRPCBatch(self)
 def __repr__(self) -> str:
- return f""
+ return f""
 async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse:
- return await (eth_call(self, params) if method == "eth_call" else RPCRequest(self, method, params)) # type: ignore [return-value]
+ call_semaphore = self.method_semaphores[method][params[1]] if method == "eth_call" else self.method_semaphores[method]
+ async with call_semaphore:
+ logger.debug(f'making {self.request_type.__name__} {method} with params {params}')
+ call = eth_call(self, params) if method == "eth_call" and params[0]["to"] not in self.no_multicall else RPCRequest(self, method, params)
+ return await call
- @property
- def batcher(self) -> NotSoBrightBatcher:
- return self.worker.batcher
-
- @property
- def pools_closed_lock(self) -> threading.Lock:
- return self.call_uid.lock
-
- async def taskmaster_loop(self) -> None:
- self.is_running = True
- while self.pending_eth_calls or self.pending_rpc_calls:
- await asyncio.sleep(0)
- if (self.loop_is_ready or self.queue_is_full):
- await self.execute_multicall()
- self.is_running = False
-
- async def execute_multicall(self) -> None:
- i = 0
- while self.pools_closed_lock.locked():
- if i // 500 == int(i // 500):
- main_logger.debug('lock is locked')
- i += 1
- await asyncio.sleep(.1)
- with self.pools_closed_lock:
- eth_calls: DefaultDict[BlockId, List[eth_call]] = defaultdict(list)
- for call in self.pending_eth_calls:
- eth_calls[call.block].append(call)
+ @eth_retry.auto_retry
+ async def make_request(self, method: str, params: List[Any], request_id: Optional[int] = None) -> RawResponse:
+ request =
self.request_type(method=method, params=params, id=request_id or self.call_uid.next)
+ return await session.post(self.endpoint, data=request, loads=decode.raw)
+
+ async def execute_batch(self) -> None:
+ with self.pools_closed_lock: # Do we really need this? # NOTE: yes we do
+ multicalls = dict(self.pending_eth_calls)
 self.pending_eth_calls.clear()
 self.num_pending_eth_calls = 0
 rpc_calls = self.pending_rpc_calls[:]
- self.pending_rpc_calls.clear()
- demo_logger.info(f'executing multicall (current cid: {self.call_uid.latest})') # type: ignore
- await self.worker.execute_batch(eth_calls, rpc_calls)
-
- @sort_lazy_logger
- def should_batch(self, method: RPCEndpoint, params: Any) -> bool:
- """ Determines whether or not a call should be passed to the DankMiddlewareController. """
- if method == "eth_call" and params[0]["to"] == self.multicall2:
- # These most likely come from dank mids internals if you're using dank mids.
- return False
- if any(bypass in method for bypass in BYPASS_METHODS):
- main_logger.debug(f"bypassed, method is {method}")
- return False
- return True
+ self.pending_rpc_calls = JSONRPCBatch(self)
+ demo_logger.info(f'executing dank batch (current cid: {self.call_uid.latest})') # type: ignore
+ batch = DankBatch(self, multicalls, rpc_calls)
+ await batch
+ demo_logger.info(f'{batch} done')
- @property
- def loop_is_ready(self) -> bool:
- return time() - self._checkpoint > LOOP_INTERVAL
-
 @property
 def queue_is_full(self) -> bool:
- return bool(len(self.pending_eth_calls) >= self.batcher.step * 25)
+ with self.pools_closed_lock:
+ if ENVS.OPERATION_MODE.infura:
+ return sum(len(call) for call in self.pending_rpc_calls) >= ENVS.MAX_JSONRPC_BATCH_SIZE
+ eth_calls = sum(len(calls) for calls in self.pending_eth_calls.values())
+ other_calls = sum(len(call) for call in self.pending_rpc_calls)
+ return eth_calls + other_calls >= self.batcher.step
+
+ def early_start(self):
+ """Used to start all queued calls when we have enough for a full batch"""
+ with self.pools_closed_lock:
+ self.pending_rpc_calls.extend(self.pending_eth_calls.values(), skip_check=True)
+ self.pending_eth_calls.clear()
+ self.pending_rpc_calls.start()
+
+ def reduce_multicall_size(self, num_calls: int) -> None:
+ self._reduce_chunk_size(num_calls, "multicall")
 def reduce_batch_size(self, num_calls: int) -> None:
- new_step = round(num_calls * 0.99) if num_calls >= 100 else num_calls - 1
- # NOTE: We need this check because one of the other multicalls in a batch might have already reduced `self.batcher.step`
- if new_step < self.batcher.step:
- old_step = self.batcher.step
- self.batcher.step = new_step
- main_logger.warning(f'Multicall batch size reduced from {old_step} to {new_step}. The failed batch had {num_calls} calls.')
+ self._reduce_chunk_size(num_calls, "jsonrpc batch")
+
+ def _reduce_chunk_size(self, num_calls: int, chunk_name: Literal["multicall", "jsonrpc batch"]) -> None:
+ new_chunk_size = round(num_calls * 0.99) if num_calls >= 100 else num_calls - 1
+ if new_chunk_size < 30:
+ logger.warning(f"your {chunk_name} batch size is really low, did you have some connection issue earlier? You might want to restart your script.
{chunk_name} chunk size will not be further lowered.") + return + # NOTE: We need the 2nd check because one of the other calls in a batch might have already reduced the chunk size + if chunk_name == "jsonrpc batch": + if new_chunk_size < ENVS.MAX_JSONRPC_BATCH_SIZE: + old_chunk_size = ENVS.MAX_JSONRPC_BATCH_SIZE + ENVS.MAX_JSONRPC_BATCH_SIZE = new_chunk_size + else: + logger.info("new chunk size %s is not lower than max batch size %s", new_chunk_size, str(ENVS.MAX_JSONRPC_BATCH_SIZE)) + return + elif chunk_name == "multicall": + if new_chunk_size < self.batcher.step: + old_chunk_size = self.batcher.step + self.batcher.step = new_chunk_size + else: + logger.info("new chunk size %s is not lower than batcher step %s", new_chunk_size, self.batcher.step) + return + else: + raise DankMidsInternalError(ValueError(f"chunk name {chunk_name} is invalid")) + logger.warning(f'{chunk_name} batch size reduced from {old_chunk_size} to {new_chunk_size}. The failed batch had {num_calls} calls.') diff --git a/dank_mids/executor.py b/dank_mids/executor.py index 1d4d1646..a6b4bfaf 100644 --- a/dank_mids/executor.py +++ b/dank_mids/executor.py @@ -1,112 +1,6 @@ -import queue -import threading -import weakref -from concurrent.futures import ThreadPoolExecutor, _base, thread -TEN_MINUTES = 60 * 10 - -def _worker(executor_reference, work_queue, initializer, initargs, timeout): # NOTE: NEW 'timeout' - if initializer is not None: - try: - initializer(*initargs) - except BaseException: - _base.LOGGER.critical('Exception in initializer:', exc_info=True) - executor = executor_reference() - if executor is not None: - executor._initializer_failed() - return - - try: - while True: - try: # NOTE: NEW - work_item = work_queue.get(block=True, - timeout=timeout) # NOTE: NEW - except queue.Empty: # NOTE: NEW - # Its been 'timeout' seconds and there are no new work items. # NOTE: NEW - # Let's suicide the thread. # NOTE: NEW - executor = executor_reference() # NOTE: NEW - - with executor._adjusting_lock: # NOTE: NEW - # NOTE: We keep a minimum of one thread active to prevent locks - if len(executor) > 1: # NOTE: NEW - t = threading.current_thread() # NOTE: NEW - executor._threads.remove(t) # NOTE: NEW - thread._threads_queues.pop(t) # NOTE: NEW - # Let the executor know we have one less idle thread available - executor._idle_semaphore.acquire(blocking=False) # NOTE: NEW - return # NOTE: NEW - continue - - if work_item is not None: - work_item.run() - # Delete references to object. See issue16284 - del work_item - - # attempt to increment idle count - executor = executor_reference() - if executor is not None: - executor._idle_semaphore.release() - del executor - continue - - executor = executor_reference() - # Exit if: - # - The interpreter is shutting down OR - # - The executor that owns the worker has been collected OR - # - The executor that owns the worker has been shutdown OR - if thread._shutdown or executor is None or executor._shutdown: - # Flag the executor as shutting down as early as possible if it - # is not gc-ed yet. - if executor is not None: - executor._shutdown = True - # Notice other workers - work_queue.put(None) - return - del executor - except BaseException: - _base.LOGGER.critical('Exception in worker', exc_info=True) - -class PruningThreadPoolExecutor(ThreadPoolExecutor): - """ - This ThreadPoolExecutor implementation prunes inactive threads after 'timeout' seconds without a work item. - Pruned threads will be automatically recreated as needed for future workloads. 
up to 'max_threads' can be active at any one time. - """ - def __init__(self, max_workers=None, thread_name_prefix='', - initializer=None, initargs=(), timeout=TEN_MINUTES): - self._timeout=timeout - self._adjusting_lock = threading.Lock() - super().__init__(max_workers, thread_name_prefix, initializer, initargs) - - def __repr__(self) -> str: - return f"" - - def __len__(self) -> int: - return len(self._threads) - - def _adjust_thread_count(self): - with self._adjusting_lock: - # if idle threads are available, don't spin new threads - if self._idle_semaphore.acquire(timeout=0): - return - - # When the executor gets lost, the weakref callback will wake up - # the worker threads. - def weakref_cb(_, q=self._work_queue): - q.put(None) - - num_threads = len(self._threads) - if num_threads < self._max_workers: - thread_name = '%s_%d' % (self._thread_name_prefix or self, - num_threads) - t = threading.Thread(name=thread_name, target=_worker, - args=(weakref.ref(self, weakref_cb), - self._work_queue, - self._initializer, - self._initargs, - self._timeout)) - t.daemon = True - t.start() - self._threads.add(t) - thread._threads_queues[t] = self._work_queue - -executor = PruningThreadPoolExecutor(128) +import logging +logger = logging.getLogger(__name__) +logger.warning("dank_mids.executor module has been deprecated and will be removed eventually.") +logger.warning("you can now import what you need from a_sync.primitives module https://github.com/BobTheBuidler/ez-a-sync") +from a_sync.primitives.executor import * \ No newline at end of file diff --git a/dank_mids/helpers/__init__.py b/dank_mids/helpers/__init__.py new file mode 100644 index 00000000..16174c0e --- /dev/null +++ b/dank_mids/helpers/__init__.py @@ -0,0 +1,3 @@ + +from dank_mids.helpers.helpers import (await_all, setup_dank_w3, + setup_dank_w3_from_sync) diff --git a/dank_mids/helpers/decode.py b/dank_mids/helpers/decode.py new file mode 100644 index 00000000..df6939a8 --- /dev/null +++ b/dank_mids/helpers/decode.py @@ -0,0 +1,11 @@ + +from msgspec import Raw, json + +from dank_mids.types import (PartialResponse, RawResponse, + _JSONRPCBatchResponse, nested_dict_of_stuff) + +raw = lambda data: RawResponse(json.decode(data, type=Raw)) +nested_dict = lambda data: json.decode(data, type=nested_dict_of_stuff) +jsonrpc_batch = lambda data: decoded if isinstance(decoded := json.decode(data, type=_JSONRPCBatchResponse), PartialResponse) else _map_raw(decoded) + +_map_raw = lambda decoded: list(map(RawResponse, decoded)) diff --git a/dank_mids/helpers.py b/dank_mids/helpers/helpers.py similarity index 88% rename from dank_mids/helpers.py rename to dank_mids/helpers/helpers.py index d2bd7c81..67a21e5d 100644 --- a/dank_mids/helpers.py +++ b/dank_mids/helpers/helpers.py @@ -1,29 +1,38 @@ import asyncio -from typing import (Any, Awaitable, Callable, Coroutine, Iterable, List, - Literal, Optional) +from functools import wraps +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, + Iterable, List, Literal, Optional, TypeVar) from eth_utils.curried import (apply_formatter_if, apply_formatters_to_dict, apply_key_map, is_null) from eth_utils.toolz import assoc, complement, compose, merge from hexbytes import HexBytes from multicall.utils import get_async_w3 +from typing_extensions import ParamSpec from web3 import Web3 from web3._utils.rpc_abi import RPC from web3.providers.async_base import AsyncBaseProvider from web3.providers.base import BaseProvider from web3.types import Formatters, FormattersDict, RPCEndpoint, RPCResponse 
-from dank_mids.middleware import dank_middleware from dank_mids.types import AsyncMiddleware +if TYPE_CHECKING: + from dank_mids.requests import RPCRequest + dank_w3s: List[Web3] = [] +T = TypeVar("T") +P = ParamSpec("P") + def setup_dank_w3(async_w3: Web3) -> Web3: """ Injects Dank Middleware into an async Web3 instance. """ assert async_w3.eth.is_async and isinstance(async_w3.provider, AsyncBaseProvider) # NOTE: We use this lookup to prevent errs where 2 project dependencies both depend on dank_mids and eth-brownie. if async_w3 not in dank_w3s: + # NOTE: We import here to prevent a circular import + from dank_mids.middleware import dank_middleware async_w3.middleware_onion.inject(dank_middleware, layer=0) async_w3.middleware_onion.add(geth_poa_middleware) dank_w3s.append(async_w3) @@ -38,7 +47,14 @@ async def await_all(futs: Iterable[Awaitable]) -> None: for fut in asyncio.as_completed([*futs]): await fut del fut - + +def set_done(fn: Callable[P, Awaitable[T]]): + @wraps(fn) + async def set_done_wrap(self: "RPCRequest", *args: P.args, **kwargs: P.kwargs) -> T: + retval = await fn(self, *args, **kwargs) + self._done.set() + return retval + return set_done_wrap # Everything below is in web3.py now, but dank_mids currently needs a version that predates them. diff --git a/dank_mids/helpers/session.py b/dank_mids/helpers/session.py new file mode 100644 index 00000000..caff4723 --- /dev/null +++ b/dank_mids/helpers/session.py @@ -0,0 +1,115 @@ + +import http +import logging +from enum import IntEnum +from itertools import chain +from threading import get_ident +from typing import Any, overload + +import msgspec +from aiohttp import ClientSession as DefaultClientSession +from aiohttp import ClientTimeout +from aiohttp.client_exceptions import ClientResponseError +from aiohttp.typedefs import JSONDecoder +from aiolimiter import AsyncLimiter +from async_lru import alru_cache + +from dank_mids import ENVIRONMENT_VARIABLES +from dank_mids.helpers import decode +from dank_mids.types import JSONRPCBatchResponse, PartialRequest, RawResponse + +logger = logging.getLogger("dank_mids.session") + +# NOTE: You cannot subclass an IntEnum object so we have to do some hacky shit here. +# First, set up custom error codes we might see. +class _HTTPStatusExtension(IntEnum): + WEB_SERVER_IS_RETURNING_AN_UNKNOWN_ERROR = (520, 'Web Server is Returning an Unknown Error', + 'HTTP response status code 520 Web server is returning an unknown error is an unofficial server error\n' + + 'that is specific to Cloudflare. This is a catch-all error that is used in the absence of having a\n' + + 'HTTP status code for one that is more specific.\n' + + 'Learn more at https://http.dev/520') + CLOUDFLARE_CONNECTION_TIMEOUT = (522, 'Cloudflare Connection Timeout', + 'Cloudflare is a content delivery network that acts as a gateway between a user and a website server.\n' + + 'When the 522 Connection timed out status code is received, Cloudflare attempted to connect\n' + + 'to the origin server but was not successful due to a timeout. The HTTP Connection was not established\n' + + 'most likely because the IP addresses of Cloudflare are blocked by the origin server,\n' + + 'the origin server settings are misconfigured, or the origin server is overloaded.' 
+ + 'Learn more at https://http.dev/522') + CLOUDFLARE_TIMEOUT = (524, 'Cloudflare Timeout', + '"524 A timeout occurred" is an unofficial server error that is specific to Cloudflare.\n' + + 'When the 524 A timeout occurred status code is received, it implies that a successful\n' + + 'HTTP Connection was made between Cloudflare and the origin server, however the connection\n' + + 'timed out before the HTTP request was completed. Cloudflare typically waits 100 seconds\n' + + 'for an HTTP response and returns this HTTP status code if nothing is received.\n' + + 'Learn more at https://http.dev/524') + SITE_FROZEN = (530, 'Site Frozen', + 'HTTP response status code 530 is an unofficial server error that is specific to Cloudflare\n' + + 'and Pantheon. In the case of Cloudflare, a second HTTP status code 1XXX error message will be\n' + + 'included to provide a more accurate description of the problem. For Pantheon, HTTP status code\n' + + '530 Site Frozen indicates that a site has been restricted due to inactivity.\n' + + 'Learn more at https://http.dev/530') + def __new__(cls, value, phrase, description=''): + obj = int.__new__(cls, value) + obj._value_ = value + obj.phrase = phrase + obj.description = description + return obj + +# Then, combine the standard HTTPStatus enum and the custom extension to get the full custom HTTPStatus enum we need. +HTTPStatusExtended = IntEnum('HTTPStatusExtended', [(i.name, i.value) for i in chain(http.HTTPStatus, _HTTPStatusExtension)]) + +RETRY_FOR_CODES = { + HTTPStatusExtended.BAD_GATEWAY, + HTTPStatusExtended.WEB_SERVER_IS_RETURNING_AN_UNKNOWN_ERROR, + HTTPStatusExtended.CLOUDFLARE_CONNECTION_TIMEOUT, + HTTPStatusExtended.CLOUDFLARE_TIMEOUT, +} + +limiter = AsyncLimiter(1, 0.1) # 10 requests/second + +@overload +async def post(endpoint: str, *args, loads = decode.raw, **kwargs) -> RawResponse:... +@overload +async def post(endpoint: str, *args, loads = decode.jsonrpc_batch, **kwargs) -> JSONRPCBatchResponse:... +async def post(endpoint: str, *args, loads: JSONDecoder = None, **kwargs) -> Any: + """Returns decoded json data from `endpoint`""" + session = await get_session() + return await session.post(endpoint, *args, loads=loads, **kwargs) + +async def get_session() -> "ClientSession": + return await _get_session_for_thread(get_ident()) + +class ClientSession(DefaultClientSession): + async def post(self, endpoint: str, *args, loads: JSONDecoder = None, **kwargs) -> bytes: + # Process input arguments. + if isinstance(kwargs.get('data'), PartialRequest): + logger.debug("making request for %s", kwargs['data']) + kwargs['data'] = msgspec.json.encode(kwargs['data']) + logger.debug("making request with (args, kwargs): (%s %s)", tuple(chain((endpoint), args)), kwargs) + + # Try the request until success or 5 failures. + tried = 0 + while True: + try: + async with limiter: + async with super().post(endpoint, *args, **kwargs) as response: + response = await response.json(loads=loads) + # NOTE: We check this to avoid unnecessary f-string conversions. 
+ if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"received response {response}") + return response + except ClientResponseError as e: + if e.status not in RETRY_FOR_CODES or tried >= 5: + logger.debug(f"response failed with status {HTTPStatusExtended(e.status)}.") + raise e + logger.debug(f"response failed with status {HTTPStatusExtended(e.status)}, retrying.") + tried += 1 + +@alru_cache(maxsize=None) +async def _get_session_for_thread(thread_ident: int) -> ClientSession: + """ + This makes our ClientSession threadsafe just in case. + Most everything should be run in main thread though. + """ + timeout = ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT) + return ClientSession(headers={'content-type': 'application/json'}, timeout=timeout, raise_for_status=True) diff --git a/dank_mids/loggers.py b/dank_mids/loggers.py deleted file mode 100644 index 4388b88a..00000000 --- a/dank_mids/loggers.py +++ /dev/null @@ -1,7 +0,0 @@ -import logging - -from lazy_logging import LazyLoggerFactory - -main_logger = logging.getLogger("dank_mids") -sort_logger = logging.getLogger("dank_mids.should_batch") -sort_lazy_logger = LazyLoggerFactory("DANKMIDS_SHOULDBATCH")(sort_logger) diff --git a/dank_mids/middleware.py b/dank_mids/middleware.py index dc62599a..f3a9757b 100644 --- a/dank_mids/middleware.py +++ b/dank_mids/middleware.py @@ -1,23 +1,26 @@ -from typing import Any, Callable +import logging +from threading import Thread, current_thread +from typing import Any, Callable, Dict, Tuple from web3 import Web3 -from web3.types import RPCEndpoint, RPCResponse +from web3.types import RPCEndpoint -from dank_mids.semaphore import method_semaphores +from dank_mids.controller import DankMiddlewareController from dank_mids.types import AsyncMiddleware +_controllers: Dict[Tuple[Web3, Thread], DankMiddlewareController] = {} + +logger = logging.getLogger(__name__) async def dank_middleware( make_request: Callable[[RPCEndpoint, Any], Any], web3: Web3 ) -> AsyncMiddleware: - # We import here to avoid a circular import issue - from dank_mids.controller import DankMiddlewareController - dank_mids = DankMiddlewareController(web3) - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - async with method_semaphores[method]: - if dank_mids.should_batch(method, params): - return await dank_mids(method, params) - return await make_request(method, params) - return middleware + return DankMiddlewareController(web3) +""" testing something + cache_key = web3, current_thread() + if (cache_key) not in _controllers: + logger.debug(f"cache key {cache_key} not known to dank mids, starting new controller") + _controllers[(cache_key)] = DankMiddlewareController(web3) + return _controllers[(cache_key)]""" diff --git a/dank_mids/requests.py b/dank_mids/requests.py index 39940f5f..49ad53bd 100644 --- a/dank_mids/requests.py +++ b/dank_mids/requests.py @@ -1,83 +1,69 @@ import abc import asyncio +import logging +import time from collections import defaultdict +from concurrent.futures.process import BrokenProcessPool +from contextlib import suppress from typing import (TYPE_CHECKING, Any, DefaultDict, Dict, Generator, Generic, - Iterable, Iterator, List, Optional, Tuple, TypeVar, Union) + Iterable, Iterator, List, NoReturn, Optional, Tuple, + TypeVar, Union) -import aiohttp +import a_sync import eth_retry -from aiohttp import RequestInfo -from eth_abi import decode_single, encode_single +import msgspec +from a_sync import AsyncProcessPoolExecutor, PruningThreadPoolExecutor +from aiohttp.client_exceptions import 
ClientResponseError +from eth_abi import abi, decoding from eth_typing import ChecksumAddress from eth_utils import function_signature_to_4byte_selector from hexbytes import HexBytes -from hexbytes._utils import to_bytes -from multicall.utils import run_in_subprocess -from web3 import Web3 -from web3.datastructures import AttributeDict -from web3.types import RPCEndpoint, RPCError, RPCResponse - -from dank_mids._config import (AIOHTTP_TIMEOUT, DEMO_MODE, - MAX_JSONRPC_BATCH_SIZE) +from typing_extensions import Self +from web3.types import RPCEndpoint, RPCResponse + +from dank_mids import ENVIRONMENT_VARIABLES as ENVS +from dank_mids import constants, stats from dank_mids._demo_mode import demo_logger -from dank_mids.constants import BAD_HEXES, OVERRIDE_CODE -from dank_mids.helpers import await_all -from dank_mids.loggers import main_logger -from dank_mids.types import BatchId, BlockId, JsonrpcParams, RpcCallJson +from dank_mids._exceptions import (BadResponse, DankMidsClientResponseError, + DankMidsInternalError, EmptyBatch, + PayloadTooLarge, ResponseNotReady, + internal_err_types) +from dank_mids.helpers import decode, session +from dank_mids.helpers.helpers import set_done +from dank_mids.types import (BatchId, BlockId, JSONRPCBatchResponse, + JsonrpcParams, PartialRequest, PartialResponse, + RawResponse, Request, Response) +from dank_mids.uid import _AlertingRLock if TYPE_CHECKING: from dank_mids.controller import DankMiddlewareController - from dank_mids.worker import DankWorker - -RETRY_ERRS = ["connection reset by peer","request entity too large","server disconnected","execution aborted (timeout = 5s)"] -class ResponseNotReady(Exception): - pass -def _call_failed(data: Optional[Union[bytes, Exception]]) -> bool: - """ Returns True if `data` indicates a failed response, False otherwise. """ - if data is None: - return True - # If we got a known "bad" response from the multicall, that is also a failure. - # Most likely the target contract does not support multicall2. - elif (isinstance(data, bytes) and HexBytes(data).hex() in BAD_HEXES): - return True - return False +logger = logging.getLogger(__name__) -def _reattempt_call_and_return_exception(target: ChecksumAddress, calldata: bytes, block: BlockId, w3: Web3) -> Union[bytes, Exception]: - """ NOTE: This runs synchronously in a subprocess in order to bypass Dank Middleware without blocking the event loop. """ - try: - return w3.eth.call({"to": target, "data": calldata}, block) - except Exception as e: - return e +_Response = TypeVar("_Response", Response, List[Response], RPCResponse, List[RPCResponse]) -def _err_response(e: Exception) -> RPCError: - """ Extract an error message from `e` to use in a spoof rpc response. 
""" - if isinstance(e.args[0], str) or isinstance(e.args[0], RequestInfo): - err_msg = f"DankMidsError: {e.__class__.__name__}: {e.args}" - elif isinstance(e.args[0], Exception): - err_msg = f"DankMidsError: {e.args[0].__class__.__name__}: {e.args[0].args}" - elif not hasattr(e.args[0], '__contains__'): - err_msg = f"DankMidsError: {e.__class__.__name__}: {e.args}" - elif "message" in e.args[0]: - err_msg = e.args[0]["message"] - elif "error" in e.args[0] and hasattr(e.args[0]["error"], '__contains__') and "message" in e.args[0]["error"]: - err_msg = e.args[0]["error"]["message"] - else: - raise e - return {'code': -32000, 'message': err_msg, 'data': ''} - - -_Response = TypeVar("_Response", RPCResponse, List[RPCResponse]) +class _RequestEvent(a_sync.Event): + def __init__(self, owner: "_RequestMeta") -> None: + super().__init__(debug_daemon_interval=300) + self._owner = owner + def __repr__(self) -> str: + return f"<{self.__class__.__name__} object at {hex(id(self))} [{'set' if self.is_set() else 'unset'}, waiter:{self._owner}>" + def set(self): + # Make sure we wake up the _RequestEvent's event loop if its in another thread + if asyncio.get_running_loop() == self._loop: + super().set() + else: + self._loop.call_soon_threadsafe(super().set) class _RequestMeta(Generic[_Response], metaclass=abc.ABCMeta): + controller: "DankMiddlewareController" def __init__(self) -> None: - if isinstance(self, RPCRequest): - self.uid = self.controller.call_uid.next - elif isinstance(self, _Batch): - self.uid = self.worker.controller.call_uid.next + self.uid = self.controller.call_uid.next self._response: Optional[_Response] = None + self._done = _RequestEvent(self) + self._start = time.time() def __await__(self) -> Generator[Any, None, Optional[_Response]]: return self.get_response().__await__() @@ -86,13 +72,9 @@ def __await__(self) -> Generator[Any, None, Optional[_Response]]: def __len__(self) -> int: pass - @property - def is_complete(self) -> bool: - return self._response is not None - @property def response(self) -> _Response: - if self._response is None: + if not self._done.is_set(): raise ResponseNotReady(self) return self._response @@ -100,61 +82,195 @@ def response(self) -> _Response: async def get_response(self) -> Optional[_Response]: pass + async def _debug_daemon(self) -> NoReturn: + while not self._done.is_set(): + await asyncio.sleep(60) + if not self._done.is_set(): + logger.debug(f"{self} has not received data after {time.time() - self._start}s") + ### Single requests: -class RPCRequest(_RequestMeta[RPCResponse]): - def __init__(self, controller: "DankMiddlewareController", method: RPCEndpoint, params: Any): +BYPASS_METHODS = "eth_getLogs", "trace_", "debug_" + +class RPCRequest(_RequestMeta[RawResponse]): + dict_responses = set() + str_responses = set() + + _types = set() + + def __init__(self, controller: "DankMiddlewareController", method: RPCEndpoint, params: Any, retry: bool = False): self.controller = controller self.method = method self.params = params + self.should_batch = all(bypass not in method for bypass in BYPASS_METHODS) + self._started = False + self._retry = retry super().__init__() - - if isinstance(self, eth_call) and self.multicall_compatible: - self.controller.pending_eth_calls.append(self) - else: - self.controller.pending_rpc_calls.append(self) + + with self.controller.pools_closed_lock: + if isinstance(self, eth_call) and self.multicall_compatible: + self.controller.pending_eth_calls[self.block].append(self) + else: + self.controller.pending_rpc_calls.append(self) 
         demo_logger.info(f'added to queue (cid: {self.uid})') # type: ignore
+        if logger.isEnabledFor(logging.DEBUG):
+            self._daemon = asyncio.create_task(self._debug_daemon())
 
     def __eq__(self, __o: object) -> bool:
-        if not isinstance(__o, self.__class__):
-            return False
-        return self.uid == __o.uid
-
-    def __hash__(self) -> int:
-        return self.uid
+        return self.uid == __o.uid if isinstance(__o, self.__class__) else False
 
     def __len__(self) -> int:
+        # NOTE: These are totally arbitrary
+        if self.method == "eth_getTransactionReceipt":
+            return 10
+        elif any(m in self.method for m in ["eth_getCode", "eth_getBlockBy", "eth_getTransaction"]):
+            return 6
         return 1
 
     def __repr__(self) -> str:
         return f"<{self.__class__.__name__} uid={self.uid} method={self.method}>"
 
     @property
-    def rpc_data(self) -> RpcCallJson:
-        return {'jsonrpc': '2.0', 'id': self.uid, 'method': self.method, 'params': self.params}
+    def request(self) -> Union[Request, PartialRequest]:
+        return self.controller.request_type(method=self.method, params=self.params, id=self.uid)
+
+    def start(self, batch: "_Batch") -> None:
+        self._started = True
+        self._batch = batch
 
+    @set_done
     async def get_response(self) -> RPCResponse:
-        if not self.controller.is_running:
-            await self.controller.taskmaster_loop()
-        while not self.is_complete:
+        if not self.should_batch:
+            logger.debug(f"bypassed, method is {self.method}")
+            try:
+                await asyncio.wait_for(self.make_request(), timeout=ENVS.STUCK_CALL_TIMEOUT)
+            except asyncio.TimeoutError:
+                return await self.create_duplicate()
+            return self.response.decode(partial=True).to_dict(self.method)
+
+        if self._started and not self._batch._started:
+            # NOTE: If we're already started, we filled a batch. Let's await it now so we can send something to the node.
+            await self._batch
+        if not self._started:
+            # NOTE: We want to force the event loop to make one full _run_once call before we execute.
             await asyncio.sleep(0)
+        if not self._started:
+            try:
+                await asyncio.wait_for(
+                    # If this timeout fails, we go nuclear and destroy the batch.
+                    # Any calls that already succeeded will have already completed on the client side.
+                    # Any calls that have not yet completed with results will be recreated, rebatched (potentially bringing better results?), and retried
+                    self.controller.execute_batch(),
+                    timeout=ENVS.STUCK_CALL_TIMEOUT,
+                )
+            except asyncio.TimeoutError:
+                return await self.create_duplicate()
+
+        try:
+            await asyncio.wait_for(self._done.wait(), timeout=ENVS.STUCK_CALL_TIMEOUT)
+        except asyncio.TimeoutError:
+            return await self.create_duplicate()
+
+        # JIT json decoding
+        if isinstance(self.response, RawResponse):
+            response = self.response.decode(partial=True).to_dict(self.method)
+            if 'error' in response:
+                if response['error']['message'] == 'invalid request':
+                    if self.controller._time_of_request_type_change == 0:
+                        self.controller.request_type = Request
+                        self.controller._time_of_request_type_change = time.time()
+                    if time.time() - self.controller._time_of_request_type_change <= 600:
+                        logger.info("your node says the partial request was invalid but it's okay, we can use the full jsonrpc spec instead")
+                        return await self.controller(self.method, self.params)
+                response['error']['dankmids_added_context'] = self.request.to_dict()
+                # I'm 99.99999% sure that any errd call has no result and we only get this field from msgspec object defs
+                # But I'll check it anyway to be safe
+                if result := response.pop('result', None):
+                    response['result'] = result
+                logger.debug("error response for %s: %s", self, response)
+            return response
+
+        # If we have an Exception here it came from the goofy sync_call thing I need to get rid of.
+        # We raise it here so it traces back up to the caller
+        if isinstance(self.response, ClientResponseError):
+            raise DankMidsClientResponseError(self.response, self.request) from self.response
+        if isinstance(self.response, Exception):
+            try:
+                more_detailed_exc = self.response.__class__(self.response, self.request)
+            except Exception as e:
+                self.response.request = self.request
+                self.response._dank_mids_exception = e
+                raise self.response
+            raise more_detailed_exc from None
+        # Less optimal decoding
+        # TODO: refactor this out
         return self.response
 
-    async def spoof_response(self, data: Union[str, AttributeDict, Exception]) -> None:
-        spoof = {"id": self.uid, "jsonrpc": "dank_mids"}
-        if isinstance(data, Exception):
-            spoof["error"] = _err_response(data)
-        else:
-            spoof["result"] = data # type: ignore
-        if isinstance(self, eth_call):
-            main_logger.debug(f"method: eth_call address: {self.target} spoof: {spoof}")
+    @set_done
+    async def spoof_response(self, data: Union[RawResponse, bytes, Exception]) -> None:
+        # sourcery skip: merge-duplicate-blocks
+        """
+        `Raw` type data comes from rpc calls executed in a jsonrpc batch
+        `bytes` type data comes from individual eth_calls that were batched into multicalls and already decoded
+        `Exception` type data comes from failed calls
+        """
+
+        # New handler
+        if isinstance(data, RawResponse):
+            self._response = data
+        elif isinstance(data, BadResponse):
+            if data.response.error.message == 'invalid request':
+                if self.controller._time_of_request_type_change == 0:
+                    self.controller.request_type = Request
+                    self.controller._time_of_request_type_change = time.time()
+                if time.time() - self.controller._time_of_request_type_change <= 600:
+                    logger.info("your node says the partial request was invalid but it's okay, we can use the full jsonrpc spec instead")
+                self._response = await self.create_duplicate()
+                return
+            error = data.response.error.to_dict()
+            error['dankmids_added_context'] = self.request.to_dict()
+            self._response = {"error": error}
logger.debug("%s _response set to rpc error response %s", self, self._response) + elif isinstance(data, Exception): + logger.debug("%s _response set to Exception %s", self, data) + self._response = data + # From multicalls + elif isinstance(data, bytes): + self._response = {"result": data} else: - main_logger.debug(f"method: {self.method} spoof: {spoof}") - self._response = spoof # type: ignore + raise NotImplementedError(f'type {type(data)} not supported for spoofing.', type(data), data) + + @set_done + async def make_request(self) -> RawResponse: + """Used to execute the request with no batching.""" + self._started = True + self._response = await self.controller.make_request(self.method, self.params, request_id=self.uid) + return self._response + + @property + def semaphore(self) -> a_sync.Semaphore: + # NOTE: We cannot cache this property so the semaphore control pattern in the `duplicate` fn will work as intended + semaphore = self.controller.method_semaphores[self.method] + if self.method == 'eth_call': + semaphore = semaphore[self.params[1]] + return semaphore + + async def create_duplicate(self) -> Self: # Not actually self, but for typing purposes it is. + # We need to make room since the stalled call is still holding the semaphore + self.semaphore.release() + # We need to check the semaphore again to ensure we have the right context manager, soon but not right away. + # Creating the task before awaiting the new call ensures the new call will grab the semaphore immediately + # and then the task will try to acquire at the very next event loop _run_once cycle + logger.warning(f"call {self.uid} got stuck, we're creating a new one") + retval = await self.controller(self.method, self.params) + await self.semaphore.acquire() + return retval + +revert_threads = PruningThreadPoolExecutor(4) class eth_call(RPCRequest): - def __init__(self, controller: "DankMiddlewareController", params: Any) -> None: + def __init__(self, controller: "DankMiddlewareController", params: Any, retry: bool = False) -> None: """ Adds a call to the DankMiddlewareContoller's `pending_eth_calls`. """ super().__init__(controller, "eth_call", params) # type: ignore @@ -174,25 +290,34 @@ def multicall_compatible(self) -> bool: def target(self) -> str: return self.params[0]["to"] - async def spoof_response(self, data: Union[bytes, Exception]) -> None: # type: ignore + async def spoof_response(self, data: Union[bytes, Exception, RawResponse]) -> None: # type: ignore """ Sets and returns a spoof rpc response for this BatchedCall instance using data provided by the worker. """ - # NOTE: If multicall failed, make sync call to get either: - # - revert details - # - successful response - if _call_failed(data): - data = await self.sync_call() - await super().spoof_response(data.hex() if isinstance(data, bytes) else data) - - async def sync_call(self) -> Union[bytes, Exception]: - """ Used to bypass DankMiddlewareController. """ - data = await run_in_subprocess( - _reattempt_call_and_return_exception, self.target, self.calldata, self.block, self.controller.sync_w3 - ) - # If we were able to get a usable response from single call, add contract to `do_not_batch`. - if not isinstance(data, Exception): - self.controller.no_multicall.add(self.target) # type: ignore - return data + # NOTE: If `type(data)` is `bytes`, it is a result from a multicall. If not, `data` comes from a jsonrpc batch. + # If this if clause is True, it means the call reverted inside of a multicall but returned a result, without causing the multicall to revert. 
+ if isinstance(data, bytes) and any(data.startswith(selector) for selector in constants.REVERT_SELECTORS): + # TODO figure out how to include method selector in no_multicall key + try: + # NOTE: If call response from multicall indicates failure, make sync call to get either: + # - successful response + # - revert details from exception + # If we get a successful response, most likely the target contract does not support multicall2. + # TODO: Get rid of the sync executor and just use `make_request` + data = await asyncio.get_event_loop().run_in_executor(revert_threads, self.controller.sync_w3.eth.call, {"to": self.target, "data": self.calldata}, self.block) + # The single call was successful. We don't want to include this contract in more multicalls + self.controller.no_multicall.add(self.target) + except Exception as e: + # NOTE: The call still returns a revert when it's not packed in a multicall + data = e + # The above revert catching logic fails to account for pre-decoding RawResponse objects. + await super().spoof_response(data) + + @property + def semaphore(self) -> a_sync.Semaphore: + # NOTE: We cannot cache this property so the semaphore control pattern in the `duplicate` fn will work as intended + return self.controller.method_semaphores['eth_call'][self.block] + + ### Batch requests: @@ -201,9 +326,12 @@ async def sync_call(self) -> Union[bytes, Exception]: class _Batch(_RequestMeta[List[RPCResponse]], Iterable[_Request]): calls: List[_Request] - def __init__(self, worker: "DankWorker", calls: Iterable[_Request]): - self.worker: DankWorker = worker + def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]): + self.controller = controller self.calls = list(calls) # type: ignore + self._fut = None + self._lock = _AlertingRLock(name=self.__class__.__name__) + #self._len = 0 super().__init__() def __bool__(self) -> bool: @@ -218,13 +346,6 @@ def __iter__(self) -> Iterator[_Request]: def __len__(self) -> int: return len(self.calls) - def append(self, call: _Request) -> None: - self.calls.append(call) - - @property - def controller(self) -> "DankMiddlewareController": - return self.worker.controller - @property def halfpoint(self) -> int: return len(self) // 2 @@ -242,28 +363,71 @@ def chunk0(self) -> List[_Request]: def chunk1(self) -> List[_Request]: return self.calls[self.halfpoint:] + def append(self, call: _Request, skip_check: bool = False) -> None: + with self._lock: + self.calls.append(call) + #self._len += 1 + if not skip_check: + if self.is_full: + self.start() + elif self.controller.queue_is_full: + self.controller.early_start() + + def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None: + with self._lock: + self.calls.extend(calls) + #self._len += len(calls) + if not skip_check: + if self.is_full: + self.start() + elif self.controller.queue_is_full: + self.controller.early_start() + + def start(self, batch: Optional["_Batch"] = None, cleanup=True) -> None: + if logger.isEnabledFor(logging.DEBUG): + self._daemon = asyncio.create_task(self._debug_daemon()) + with self._lock: + for call in self.calls: + call.start(batch or self) + if cleanup: + self._post_future_cleanup() + def should_retry(self, e: Exception) -> bool: + """Should the _Batch be retried based on `e`?""" if "out of gas" in f"{e}": # TODO Remember which contracts/calls are gas guzzlers - main_logger.debug('out of gas. cut in half, trying again') - elif any(err in f"{e}".lower() for err in RETRY_ERRS): + logger.debug('out of gas. 
cut in half, trying again') + elif isinstance(e, PayloadTooLarge) or any(err in f"{e}".lower() for err in constants.RETRY_ERRS): # TODO: use these exceptions to optimize for the user's node - main_logger.debug('Dank too loud. Bisecting batch and retrying.') + logger.debug('Dank too loud. Bisecting batch and retrying.') elif "error processing call Revert" not in f"{e}": - main_logger.warning(f"unexpected {e.__class__.__name__}: {e}") + logger.warning(f"unexpected {e.__class__.__name__}: {e}") return len(self) > 1 +mcall_encoder = abi.default_codec._registry.get_encoder("(bool,(address,bytes)[])") +mcall_decoder = abi.default_codec._registry.get_decoder("(uint256,uint256,(bool,bytes)[])") + +def mcall_encode(data: List[Tuple[bool, bytes]]) -> bytes: + return mcall_encoder([False, data]) + +def mcall_decode(data: PartialResponse) -> List[Tuple[bool, bytes]]: + try: + # NOTE: We need to safely bring any Exceptions back out of the ProcessPool + data = bytes.fromhex(data.decode_result("eth_call")[2:]) + return mcall_decoder(decoding.ContextFramesBytesIO(data))[2] + except Exception as e: + return e + class Multicall(_Batch[eth_call]): - """ Runs in worker thread. """ method = "eth_call" - fourbyte = function_signature_to_4byte_selector("tryBlockAndAggregate(bool,(address,bytes)[])") - input_types = "(bool,(address,bytes)[])" - output_types = "(uint256,uint256,(bool,bytes)[])" + fourbyte = function_signature_to_4byte_selector("tryBlockAndAggregate(bool,(address,bytes)[])") - def __init__(self, worker: "DankWorker", calls: List[eth_call] = [], bid: Optional[BatchId] = None): - super().__init__(worker, calls) - self.bid = bid or self.worker.multicall_uid.next + def __init__(self, controller: "DankMiddlewareController", calls: List[eth_call] = [], bid: Optional[BatchId] = None): + # sourcery skip: default-mutable-arg + super().__init__(controller, calls) + self.bid = bid or self.controller.multicall_uid.next + self._started = False def __repr__(self) -> str: return f"" @@ -274,35 +438,55 @@ def block(self) -> BlockId: @property def calldata(self) -> str: - return (self.fourbyte + encode_single(self.input_types, [False, [[call.target, call.calldata] for call in self.calls]])).hex() + return (self.fourbyte + mcall_encode([[call.target, call.calldata] for call in self.calls])).hex() @property def target(self) -> ChecksumAddress: - return self.worker.target + return self.controller.multicall2 @property def params(self) -> JsonrpcParams: - if self.worker.state_override_not_supported: - return [{'to': self.target, 'data': '0x' + self.calldata}, self.block] # type: ignore - return [{'to': self.target, 'data': '0x' + self.calldata}, self.block, {self.target: {'code': OVERRIDE_CODE}}] # type: ignore + if self.controller.state_override_not_supported: + return [{'to': self.target, 'data': f'0x{self.calldata}'}, self.block] # type: ignore + return [{'to': self.target, 'data': f'0x{self.calldata}'}, self.block, {self.target: {'code': constants.OVERRIDE_CODE}}] # type: ignore @property - def rpc_data(self) -> RpcCallJson: - return {'jsonrpc': '2.0', 'id': self.uid, 'method': self.method, 'params': self.params} + def request(self) -> Union[Request, PartialRequest]: + return self.controller.request_type(method=self.method, params=self.params, id=self.uid) + + @property + def is_full(self) -> bool: + return len(self) >= self.controller.batcher.step - async def get_response(self) -> List[RPCResponse]: - rid = self.worker.request_uid.next + async def get_response(self) -> None: + if self._started: + 
logger.error(f'{self} early exit') + return + self._started = True + #if len(self) < 50: # TODO play with later + # return await JSONRPCBatch(self.controller, self.calls) + rid = self.controller.request_uid.next demo_logger.info(f'request {rid} for multicall {self.bid} starting') # type: ignore try: - await self.spoof_response(await self.worker(*self.params)) + await self.spoof_response(await self.controller.make_request(self.method, self.params, request_id=self.uid)) + except internal_err_types.__args__ as e: + raise DankMidsInternalError(e) + except ClientResponseError as e: + if e.message == "Payload Too Large": + logger.info("Payload too large. response headers: %s", e.headers) + self.controller.reduce_multicall_size(len(self)) + else: + _log_exception(e) + await (self.bisect_and_retry(e) if self.should_retry(e) else self.spoof_response(e)) # type: ignore [misc] except Exception as e: - await (self.bisect_and_retry() if self.should_retry(e) else self.spoof_response(e)) # type: ignore [misc] + _log_exception(e) + await (self.bisect_and_retry(e) if self.should_retry(e) else self.spoof_response(e)) # type: ignore [misc] demo_logger.info(f'request {rid} for multicall {self.bid} complete') # type: ignore def should_retry(self, e: Exception) -> bool: - if any(err in f"{e}".lower() for err in RETRY_ERRS): - main_logger.debug('dank too loud, trying again') - self.controller.reduce_batch_size(len(self)) + """Should the Multicall be retried based on `e`?""" + if any(err in f"{e}".lower() for err in constants.RETRY_ERRS): + logger.debug('dank too loud, trying again') return True elif "No state available for block" in f"{e}": # NOTE: While it might look weird, f-string is faster than `str(e)`. e.args[0]["dankmids_note"] = "You're not using an archive node, and you need one for the application you are attempting to run." @@ -311,127 +495,259 @@ def should_retry(self, e: Exception) -> bool: return True return len(self) > 1 - async def spoof_response(self, data: Union[bytes, str, Exception]) -> None: - """ - If called from `self`, `response` will be bytes type. - if called from a JSONRPCBatch, `response` will be str type. - """ + @set_done + async def spoof_response(self, data: Union[RawResponse, Exception]) -> None: + # This happens if an Exception takes place during a singular Multicall request. if isinstance(data, Exception): - await await_all(call.spoof_response(data) for call in self.calls) + logger.debug("%s had Exception %s", self, data) + logger.debug("propagating the %s to all %s's calls", data.__class__.__name__, self) + await asyncio.gather(*[call.spoof_response(data) for call in self.calls]) + # A `RawResponse` represents either a successful or a failed response, stored as pre-decoded bytes. + # It was either received as a response to a single rpc call or as a part of a batch response. 
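+        # NOTE: `decode` below yields one (success, returndata) pair per inner call, in order, which is
+        #       why the decoded results zip cleanly against `self.calls`.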
+ elif isinstance(data, RawResponse): + response = data.decode(partial=True) + if response.error: + logger.debug("%s received an 'error' response from the rpc: %s", self, response.exception) + # NOTE: We raise the exception which will be caught, call will be broken up and retried + raise response.exception + logger.debug("%s received valid bytes from the rpc", self) + await asyncio.gather(*(call.spoof_response(data) for call, (_, data) in zip(self.calls, await self.decode(response)))) else: - decoded: List[Tuple[bool, bytes]] - _, _, decoded = await run_in_subprocess(decode_single, self.output_types, to_bytes(data)) - await await_all(call.spoof_response(data) for call, (_, data) in zip(self.calls, decoded)) + raise NotImplementedError(f"type {type(data)} not supported.", data) - async def bisect_and_retry(self) -> List[RPCResponse]: - await await_all((Multicall(self.worker, chunk, f"{self.bid}_{i}") for i, chunk in enumerate(self.bisected))) + async def decode(self, data: PartialResponse) -> List[Tuple[bool, bytes]]: + start = time.time() + if ENVS.OPERATION_MODE.infura: + retval = mcall_decode(data) + else: + try: # NOTE: Quickly check for length without counting each item with `len`. + if not ENVS.OPERATION_MODE.application: + self[100] + retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data) + except IndexError: + retval = mcall_decode(data) + except BrokenProcessPool: + # TODO: Move this somewhere else + logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data) + ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers) + retval = mcall_decode(data) + stats.log_duration(f"multicall decoding for {len(self)} calls", start) + # Raise any Exceptions that may have come out of the process pool. + if isinstance(retval, Exception): + raise retval + return retval + + @set_done + async def bisect_and_retry(self, e: Exception) -> List[RPCResponse]: + """ + Splits up the calls of a `Multicall` into 2 chunks, then awaits both. + Calls `self._done.set()` when finished. + """ + logger.debug("%s had exception %s, bisecting and retrying", self, e) + batches = [Multicall(self.controller, chunk, f"{self.bid}_{i}") for i, chunk in enumerate(self.bisected)] + for batch in batches: + batch.start(cleanup=False) + for batch, result in zip(batches, await asyncio.gather(*batches, return_exceptions=True)): + if isinstance(result, Exception): + if not isinstance(result, DankMidsInternalError): + logger.error(f"That's not good, there was an exception in a {batch.__class__.__name__}. 
These are supposed to be handled.\n{result}\n", exc_info=True) + raise result + + def _post_future_cleanup(self) -> None: + with suppress(KeyError): + with self.controller.pools_closed_lock: + # This will have already taken place in a full json batch of multicalls + self.controller.pending_eth_calls.pop(self.block) class JSONRPCBatch(_Batch[Union[Multicall, RPCRequest]]): - def __init__(self, worker: "DankWorker", calls: List[Union[Multicall, RPCRequest]] = [], jid: Optional[BatchId] = None) -> None: - super().__init__(worker, calls) - self.jid = jid or self.worker.jsonrpc_batch_uid.next - self._locked = False + def __init__( + self, + controller: "DankMiddlewareController", + calls: List[Union[Multicall, RPCRequest]] = [], + jid: Optional[BatchId] = None + ) -> None: # sourcery skip: default-mutable-arg + super().__init__(controller, calls) + self.jid = jid or self.controller.jsonrpc_batch_uid.next + self._started = False + + def __repr__(self) -> str: + return f"" @property - def data(self) -> List[RpcCallJson]: - return [call.rpc_data for call in self.calls] + def data(self) -> bytes: + if not self.calls: + raise EmptyBatch(f"batch {self.uid} is empty and should not be processed.") + return msgspec.json.encode([call.request for call in self.calls]) @property def is_multicalls_only(self) -> bool: - return all(isinstance(call, Multicall) for call in self.calls) + with self._lock: + return all(isinstance(call, Multicall) for call in self.calls) @property def is_single_multicall(self) -> bool: - return len(self) == 1 and self.is_multicalls_only + with self._lock: + return len(self) == 1 and self.is_multicalls_only @property def method_counts(self) -> Dict[RPCEndpoint, int]: counts: DefaultDict[RPCEndpoint, int] = defaultdict(int) - for call in self.calls: - counts[call.method] += len(call) # type: ignore - return dict(counts) + with self._lock: + for call in self.calls: + if isinstance(call, Multicall): + counts["eth_call[multicall]"] += len(call) # type: ignore + else: + counts[call.method] += 1 + return dict(counts) @property def total_calls(self) -> int: - return sum(len(call) for call in self.calls) + with self._lock: + return sum(len(call) for call in self.calls) + + @property + def is_full(self) -> bool: + with self._lock: + return self.total_calls >= self.controller.batcher.step or len(self) >= ENVS.MAX_JSONRPC_BATCH_SIZE - def append(self, call: Union[Multicall, RPCRequest]) -> None: - if self._locked: - raise Exception(f"{self} is locked.") - self.calls.append(call) - async def get_response(self) -> None: - """ Runs in worker thread. """ - self._locked = True - rid = self.worker.request_uid.next - if DEMO_MODE: + if self._started: + logger.warning(f"{self} exiting early. This shouldn't really happen bro") + return + self._started = True + rid = self.controller.request_uid.next + if ENVS.DEMO_MODE: # When demo mode is disabled, we can save some CPU time by skipping this sum demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} ({sum(len(batch) for batch in self.calls)} calls) starting') # type: ignore try: - responses = await self.post() - self.validate_responses(responses) - await self.spoof_response(responses) + # NOTE: We do this inline so we never have to allocate the response to memory + await self.spoof_response(await self.post()) + # I want to see these asap when working on the lib. 
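+        # NOTE: `internal_err_types` appear to be exceptions originating inside dank_mids itself, so they
+        #       are escalated as DankMidsInternalError rather than bisected and retried like node errors.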
+ except internal_err_types.__args__ as e: + raise DankMidsInternalError(e) from e + except EmptyBatch as e: + logger.warning("These EmptyBatch exceptions shouldn't actually happen and this except clause can probably be removed soon.") + except PayloadTooLarge as e: + # TODO: track too large payloads and do some better optimizations for batch sizing + self.adjust_batch_size() + await self.bisect_and_retry(e) except Exception as e: - await (self.bisect_and_retry() if self.should_retry(e) else self.spoof_response(e)) + _log_exception(e) + stats.log_errd_batch(self) + if self.should_retry(e): + await self.bisect_and_retry(e) + # NOTE: `self.should_retry(e)` can only return False here if the json batch is comprised of just one rpc request that is not a multicall. + # I include this elif clause as a failsafe. This is rare and should not impact performance. + elif not self.is_single_multicall: + # Just to force my IDE to resolve types correctly + calls : List[RPCRequest] = self.calls + logger.debug("%s had exception %s, aborting and setting Exception as call._response", self, e) + # NOTE: This means an exception occurred during the post request + # AND that the json batch is made of just one rpc request that is not a multicall. + logger.info('does this ever actually run? pretty sure a single-multicall json batch is not possible. can I delete this?') + await asyncio.gather(*[call.spoof_response(e) for call in calls]) + else: + raise NotImplementedError('and you may ask yourself, well, how did I get here?') demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} complete') # type: ignore @eth_retry.auto_retry - async def post(self) -> Union[Dict, List[bytes]]: - """ Posts `jsonrpc_batch` to your node. A successful call returns a list. """ - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - responses = await session.post(self.worker.endpoint, json=self.data) # type: ignore - try: - return await responses.json(content_type=responses.content_type) - except: - counts = self.method_counts - decoded = responses._body.decode() - main_logger.info(f"json batch id: {self.jid} | len: {len(self)} | total calls: {self.total_calls}", ) - main_logger.info(f"methods called: {counts}") - if 'content length too large' in decoded or decoded == "": - if self.is_multicalls_only: - self.controller.reduce_batch_size(self.total_calls) - raise ValueError(decoded) - # This shouldn't run unless there are issues. I'll probably delete it later. - main_logger.info(f"decoded body: {decoded}") - main_logger.info(f"exception: {responses.content._exception}") - raise + async def post(self) -> List[RawResponse]: + "this function raises `BadResponse` if a successful 'error' response was received from the rpc" + try: + response: JSONRPCBatchResponse = await session.post(self.controller.endpoint, data=self.data, loads=decode.jsonrpc_batch) + except ClientResponseError as e: + if e.message == "Payload Too Large": + logger.warning("Payload Too Large") + logger.warning("This is what was too large: %s", self.method_counts) + self.adjust_batch_size() + elif 'broken pipe' in str(e).lower(): + logger.warning("This is what broke the pipe: %s", self.method_counts) + logger.debug("caught %s for %s, reraising", e, self) + raise e + except Exception as e: + if 'broken pipe' in str(e).lower(): + logger.warning("This is what broke the pipe: %s", self.method_counts) + raise e + # NOTE: A successful response will be a list of `RawResponse` objects. + # A single `PartialResponse` implies an error. 
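+        # For example, a healthy batch decodes to [RawResponse, RawResponse, ...] with one entry per call,
+        # while a node-level failure decodes to a single PartialResponse whose `error` field is set.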
+ if isinstance(response, list): + return response + # Oops, we failed. + if response.error.message == 'invalid request': + # NOT SURE IF THIS ACTUALLY RUNS, CAN WE RECEIVE THIS TYPE RESPONSE FOR A JSON BATCH? + if self.controller._time_of_request_type_change == 0: + self.controller.request_type = Request + self.controller._time_of_request_type_change = time.time() + if time.time() - self.controller._time_of_request_type_change <= 600: + logger.info("your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead") + raise response.exception def should_retry(self, e: Exception) -> bool: + """Should the JSONRPCBatch be retried based on `e`?""" # While it might look weird, f-string is faster than `str(e)`. if "No state available for block" in f"{e}": - main_logger.debug('No state available for queried block. Bisecting batch and retrying.') - return True - elif f"{e}" == "jsonrpc": - # TODO Figure out what this means and how we can prevent it. - # For now, we simply bisect and retry. + logger.debug('No state available for one of the blocks queried. Bisecting batch and retrying.') return True elif super().should_retry(e): return True return self.is_single_multicall - async def spoof_response(self, response: Union[List[RPCResponse], Exception]) -> None: - if isinstance(response, Exception): - return await await_all(call.spoof_response(response) for call in self.calls) - return await await_all( - # NOTE: For some rpc methods, the result will be a dict we can't hash during the gather. - call.spoof_response(AttributeDict(result["result"]) if isinstance(result["result"], dict) else result["result"]) # type: ignore - for call, result in zip(self.calls, response) - ) - - def validate_responses(self, responses) -> None: - # A successful response will be a list - if isinstance(responses, dict) and 'result' in responses and isinstance(responses['result'], dict) and 'message' in responses['result']: - raise ValueError(responses['result']['message']) - for response in responses: - if 'result' not in response: - raise ValueError(response) - - async def bisect_and_retry(self) -> None: - await await_all( - Multicall(self.worker, chunk[0].calls, f"json{self.jid}_{i}") # type: ignore [misc] + @set_done + async def spoof_response(self, response: List[RawResponse]) -> None: + # This means we got results. That doesn't mean they're good, but we got 'em. + for r in await asyncio.gather(*[call.spoof_response(raw) for call, raw in zip(self.calls, response)], return_exceptions=True): + # NOTE: By doing this with the exceptions we allow any successful calls to get their results sooner + # without being interrupted by the first exc in the gather and having to wait for the bisect and retry process + # TODO: stop retrying ones that succeed, that's wasteful + if isinstance(r, Exception): + raise r + + @set_done + async def bisect_and_retry(self, e: Exception) -> None: + """ + Splits up the calls of a `JSONRPCBatch` into 2 chunks, then awaits both. + If one chunk is just a single multicall, it will be handled alone, not be placed into a batch. + Calls `self._done.set()` when finished. 
+ """ + logger.debug("%s had exception %s, retrying", self, e) + batches = [ + Multicall(self.controller, chunk[0].calls, f"json{self.jid}_{i}") # type: ignore [misc] if len(chunk) == 1 and isinstance(chunk[0], Multicall) - else JSONRPCBatch(self.worker, chunk, f"{self.jid}_{i}") + else JSONRPCBatch(self.controller, chunk, f"{self.jid}_{i}") for i, chunk in enumerate(self.bisected) if chunk - ) + ] + for batch in batches: + batch.start(cleanup=False) + for batch, result in zip(batches, await asyncio.gather(*batches, return_exceptions=True)): + if isinstance(result, Exception): + if not isinstance(result, DankMidsInternalError): + logger.error(f"That's not good, there was an exception in a {batch.__class__.__name__}. These are supposed to be handled.\n{result}\n", exc_info=True) + raise result + + def adjust_batch_size(self) -> None: + if self.is_multicalls_only: + logger.info('checking if we should reduce multicall batch size...') + self.controller.reduce_multicall_size(self.total_calls) + else: + logger.info('checking if we should reduce json batch size...') + self.controller.reduce_batch_size(len(self)) + stats.logger.devhint( + "We still need some better logic for catching these errors and using them to better optimize the batching process" + ) + + def _post_future_cleanup(self) -> None: + with self.controller.pools_closed_lock: + self.controller.pending_rpc_calls = JSONRPCBatch(self.controller) + +def _log_exception(e: Exception) -> None: + # NOTE: These errors are expected during normal use and are not indicative of any problem(s). No need to log them. + # TODO: Better filter what we choose to log here + dont_need_to_see_errs = constants.RETRY_ERRS + ['out of gas','error processing call revert', 'non_empty_data', 'invalid ether transfer'] + dont_need_to_see_errs += ["invalid request"] # We catch and correct these + stre = str(e).lower() + if any(err in stre for err in dont_need_to_see_errs): + return + logger.warning("The following exception is being logged for informational purposes and does not indicate failure:") + logger.warning(e, exc_info=True) diff --git a/dank_mids/semaphore.py b/dank_mids/semaphore.py index 28192ffc..d2da6152 100644 --- a/dank_mids/semaphore.py +++ b/dank_mids/semaphore.py @@ -1,52 +1,6 @@ -import asyncio -from typing import Union -from web3.types import RPCEndpoint -from dank_mids._config import semaphore_envs -from threading import current_thread -from typing import Dict - -class ThreadsafeSemaphore: - """ - While its a bit weird to run multiple event loops, sometimes either you or a lib you're using must do so. - When in use in threaded applications, the semaphore will not work as intended but at least your program will function. - You may need to reduce the semaphore value for multi-threaded applications. - """ - - def __init__(self, value: int) -> None: - self.default_value = value - self.semaphores: Dict[int, asyncio.Semaphore] = {} - - @property - def semaphore(self) -> asyncio.Semaphore: - tid = current_thread() - if tid not in self.semaphores: - self.semaphores[tid] = asyncio.Semaphore(self.default_value) # type: ignore [index] - return self.semaphores[tid] # type: ignore [index] - - async def __aenter__(self): - await self.semaphore.acquire() - - async def __aexit__(self, *args): - self.semaphore.release() - - -class _DummySemaphore: - async def __aenter__(self): - ... - async def __aexit__(self, *args): - ... 
- -class Semaphores: - def __init__(self) -> None: - self.method_semaphores = {key: ThreadsafeSemaphore(value) for key, value in semaphore_envs.items()} - self.keys = self.method_semaphores.keys() - self.dummy = _DummySemaphore() - - def __getitem__(self, method: RPCEndpoint) -> Union[ThreadsafeSemaphore, _DummySemaphore]: - for key in self.keys: - if key in method: - return self.method_semaphores[key] - return self.dummy - -method_semaphores = Semaphores() +import logging +logger = logging.getLogger(__name__) +logger.warning("dank_mids.semaphore module has been deprecated and will be removed eventually.") +logger.warning("you can now import what you need from a_sync.primitives module https://github.com/BobTheBuidler/ez-a-sync") +from a_sync.primitives.locks.semaphore import * \ No newline at end of file diff --git a/dank_mids/semaphores.py b/dank_mids/semaphores.py new file mode 100644 index 00000000..a8fecde0 --- /dev/null +++ b/dank_mids/semaphores.py @@ -0,0 +1,51 @@ + +import logging +from typing import TYPE_CHECKING, Literal, Union + +from a_sync.primitives import DummySemaphore, ThreadsafeSemaphore +from a_sync.primitives.locks.prio_semaphore import ( + _AbstractPrioritySemaphore, _PrioritySemaphoreContextManager) +from web3.types import RPCEndpoint + +if TYPE_CHECKING: + from dank_mids.controller import DankMiddlewareController + +logger = logging.getLogger(__name__) + +class _BlockSemaphoreContextManager(_PrioritySemaphoreContextManager): + _priority_name = "block" + +class BlockSemaphore(_AbstractPrioritySemaphore[str, _BlockSemaphoreContextManager]): + _context_manager_class = _BlockSemaphoreContextManager + _top_priority = -1 + def __getitem__(self, block: Union[int, str, Literal["latest", None]]) -> "_BlockSemaphoreContextManager": + return super().__getitem__( + block if isinstance(block, int) + else int(block.hex(), 16) if isinstance(block, bytes) + else int(block, 16) if isinstance(block, str) and "0x" in block + else block if block not in [None, 'latest'] # NOTE: We do this to generate an err if an unsuitable value was provided + else self._top_priority + ) + # NOTE: do we break anything if we no longer use these extra logs? 
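+    # Illustration: for a BlockSemaphore `sem`, `sem[12345678]`, `sem["0xbc614e"]`, and
+    # `sem[bytes.fromhex("bc614e")]` all resolve to the same per-block context manager,
+    # while `sem["latest"]` and `sem[None]` map to `_top_priority`.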
+ #async def acquire(self): + # logger.debug("acquiring %s", self) + # await super().acquire() + # logger.debug("acquired %s", self) + #def release(self): + # super().release() + # logger.debug("released %s", self) + + +class MethodSemaphores: + def __init__(self, controller: "DankMiddlewareController") -> None: + from dank_mids import ENVIRONMENT_VARIABLES + self.controller = controller + self.method_semaphores = { + method: (BlockSemaphore if method == "eth_call" else ThreadsafeSemaphore)(sem._value, name=f"{method} {controller}") + for method, sem in ENVIRONMENT_VARIABLES.method_semaphores.items() + } + self.keys = self.method_semaphores.keys() + self.dummy = DummySemaphore() + + def __getitem__(self, method: RPCEndpoint) -> Union[ThreadsafeSemaphore, DummySemaphore]: + return next((self.method_semaphores[key] for key in self.keys if key in method), self.dummy) diff --git a/dank_mids/stats.py b/dank_mids/stats.py new file mode 100644 index 00000000..bcc5658b --- /dev/null +++ b/dank_mids/stats.py @@ -0,0 +1,235 @@ + +# TODO: Robust and Refactor + +import asyncio +import logging +from collections import defaultdict, deque +from concurrent.futures import ProcessPoolExecutor +from time import time +from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Deque, Set, + Type, Union) + +import msgspec +from typed_envs.registry import _ENVIRONMENT_VARIABLES_SET_BY_USER +from web3.types import RPCEndpoint + +from dank_mids import ENVIRONMENT_VARIABLES as ENVS + +if TYPE_CHECKING: + from dank_mids.requests import JSONRPCBatch + from dank_mids.types import Request + +_LogLevel = Union[int, str] + +# New logging levels: +# DEBUG=10, INFO=20, +STATS = 13 +DEVHINT = 15 + +COLLECT_STATS: bool = False # TODO: enable this + +# if you're both collecting data and logging something, put the function here: + +def log_errd_batch(batch: "JSONRPCBatch") -> None: + collector.errd_batches.append(batch) + logger.devhint(f"jsonrpc batch failed\njson batch id: {batch.jid} | len: {len(batch)} | total calls: {batch.total_calls}\n" + + f"methods called: {batch.method_counts}") + +def log_duration(work_descriptor: str, start: float, *, level=STATS) -> None: + # sourcery skip: hoist-if-from-if + enabled = logger.isEnabledFor(level) + if COLLECT_STATS or enabled: + duration = time() - start + if COLLECT_STATS: + collector.durations[work_descriptor].append(duration) + if enabled: + logger._log_nocheck(level, f"{work_descriptor} took {duration}") + + +class _StatsLogger(logging.Logger): + """A specialized logger for logging stats related to dank mids""" + + @property + def enabled(self) -> bool: + """Returns `True` if level is set to `STATS` (`11`) or below.""" + self._ensure_daemon() + return self.isEnabledFor(STATS) + + # New logging levels + + def stats(self, msg, *args, **kwargs) -> None: + if self.enabled: + self._log_nocheck(STATS, msg, args, **kwargs) + + def devhint(self, msg, *args, **kwargs) -> None: + self._log(DEVHINT, msg, args, **kwargs) + + # Functions to print stats to your logs. 
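+    # For example, `log_event_loop_stats(level=logging.INFO)` formats `collector.avg_loop_time` via
+    # `_Writer.event_loop` and emits it at INFO level (illustrative call; see `_log_fn_result` below).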
+
+    def log_brownie_stats(self, *, level: _LogLevel = STATS) -> None:
+        self._log_fn_result(level, _Writer.brownie)
+
+    def log_event_loop_stats(self, *, level: _LogLevel = STATS) -> None:
+        self._log_fn_result(level, _Writer.event_loop)
+
+    def log_subprocess_stats(self, *, level: _LogLevel = STATS) -> None:
+        for pool in {ENVS.BROWNIE_ENCODER_PROCESSES, ENVS.BROWNIE_DECODER_PROCESSES, ENVS.MULTICALL_DECODER_PROCESSES}:
+            self._log_fn_result(level, _Writer.queue, pool)
+
+    # Internal helpers
+
+    def _log(self, level: _LogLevel, *args, **kwargs) -> None:
+        # Saves us having to do this ourselves for each custom message
+        if self.isEnabledFor(level):
+            return self._log_nocheck(level, *args, **kwargs)
+
+    def _log_nocheck(self, level: _LogLevel, *args, **kwargs) -> None:
+        try:
+            return super()._log(level, args[0], args[1:], **kwargs)
+        except IndexError:
+            raise ValueError("Both a level and a message are required.") from None
+
+    def _log_fn_result(self, level: _LogLevel, callable: Callable[[], str], *callable_args, **logging_kwargs) -> None:
+        """If `self.isEnabledFor(level)` is True, will call `callable` with your args and log the output."""
+        if self.isEnabledFor(level):
+            return self._log_nocheck(level, callable(*callable_args), (), **logging_kwargs)
+
+    # Daemon
+
+    def _ensure_daemon(self) -> None:
+        # NOTE: We check `isEnabledFor` directly instead of the `enabled` property, which would recurse back
+        #       into this method, and use `getattr` because nothing initializes `_daemon` before the first check.
+        daemon = getattr(self, "_daemon", None)
+        if (ENVS.COLLECT_STATS or self.isEnabledFor(STATS)) and daemon is None:
+            self._daemon = asyncio.create_task(self._stats_daemon())
+        elif daemon is not None and daemon.done():
+            raise daemon.exception()
+
+    async def _stats_daemon(self) -> None:
+        start = time()
+        time_since_notified = 0
+        while True:
+            await asyncio.sleep(0)
+            now = time()
+            duration = now - start
+            collector.event_loop_times.append(duration)
+            start = time()
+            time_since_notified += duration
+            if time_since_notified > 300:
+                self.log_subprocess_stats(level=logging.INFO)
+                self.log_brownie_stats(level=logging.INFO)
+                self.log_event_loop_stats(level=logging.INFO)
+                time_since_notified = 0
+
+    # TODO: MOVE COLLECTIONS OUT OF THIS CLASS
+
+    def log_validation_error(self, method: RPCEndpoint, e: msgspec.ValidationError) -> None:
+        enabled = self.isEnabledFor(DEVHINT)
+        if COLLECT_STATS or enabled:
+            collector.validation_errors[method].append(e)
+        if enabled:
+            self._log(DEVHINT, f"ValidationError when decoding response for {method}", ("This *should* not impact your script. 
If it does, you'll know."), e) + + def log_types(self, method: RPCEndpoint, decoded: Any) -> None: + # TODO fix this, use enabled check + types = {type(v) for v in decoded.values()} + self.devhint(f'my method and types: {method} {types}') + if list in types: + self._log_list_types(decoded.values()) + collector.types.update(types) + + def _log_list_types(self, values, level: _LogLevel = DEVHINT) -> None: + list_types = {type(_) for v in values if isinstance(v, list) for _ in v} + collector.types.update(list_types) + if self.isEnabledFor(level): + self._log(level, f"list types: {list_types}") + + +_Times = Deque[float] + +class _Collector: + def __init__(self): + """Handles the collection and computation of stats-related data.""" + self.errd_batches = deque(maxlen=500) + self.durations: DefaultDict[str, _Times] = defaultdict(lambda: deque(maxlen=50_000)) + self.types: Set[Type] = set() + self.event_loop_times: _Times = deque(maxlen=50_000) + # not implemented + self.validation_errors: DefaultDict[RPCEndpoint, Deque["Request"]] = defaultdict(lambda: deque(maxlen=100)) + + @property + def avg_loop_time(self) -> float: + return sum(collector.event_loop_times) / len(collector.event_loop_times) + @property + def count_active_brownie_calls(self) -> int: + return ENVS.BROWNIE_CALL_SEMAPHORE.default_value - ENVS.BROWNIE_CALL_SEMAPHORE.semaphore._value + @property + def count_queued_brownie_calls(self) -> int: + return len(ENVS.BROWNIE_CALL_SEMAPHORE.semaphore._waiters) + @property + def encoder_queue_len(self) -> int: + return ENVS.BROWNIE_ENCODER_PROCESSES._queue_count + @property + def decoder_queue_len(self) -> int: + return ENVS.BROWNIE_DECODER_PROCESSES._queue_count + @property + def mcall_decoder_queue_len(self) -> int: + return ENVS.MULTICALL_DECODER_PROCESSES._queue_count + + +class _Writer: + """ + `Writer` is used to turn `Collector` stats into human readable on a as-needed, JIT basis + without wasting compute or cluttering `Collector` or `StatsLogger` class definitions. + """ + def event_loop(self) -> str: + return f"Average event loop time: {collector.avg_loop_time}" + def brownie(self) -> str: + return f"{collector.count_active_brownie_calls} brownie calls are processing, {collector.count_queued_brownie_calls} are queued in {ENVS.BROWNIE_CALL_SEMAPHORE}." + def queue(self, pool: ProcessPoolExecutor) -> str: + return f"{pool} has {pool._queue_count} items in its queue" + + +class _SentryExporter: + """ + Pushes all metrics from the `metrics` dict to sentry. + Each metric value will be fetched by calling `getattr(collector, metrics[k])`. + If the result is a callable object, it will be called without args. + """ + metrics = { + "active_eth_calls": "count_active_brownie_calls", + "queued_eth_calls": "count_queued_brownie_calls", + "encoder_queue": "encoder_queue_len", + "decoder_queue": "decoder_queue_len", + "loop_time": "avg_loop_time", + } + units = {"loop_time": "seconds"} + + def push_measurements(self) -> None: + """Pushes all metrics in `self._metrics` to sentry""" + if self._exc: + raise self._exc + for tag, attr_name in self.metrics.items(): + attr = getattr(collector, attr_name) + if callable(attr): + attr = attr() + self.set_measurement(tag, attr, self.units.get(attr_name)) + + def push_envs(self) -> None: + for env, value in _ENVIRONMENT_VARIABLES_SET_BY_USER.items(): + try: + self.set_tag(env, value) + except Exception as e: + logger.warning(f"Unable to set sentry tag {env} to {value}. 
See {e.__class__.__name__} below:")
+                logger.info(e, exc_info=True)
+    try:
+        import sentry_sdk
+        set_tag = sentry_sdk.set_tag
+        set_measurement = sentry_sdk.set_measurement
+        _exc = None
+    except ImportError as e:
+        _exc = e
+
+
+logger = _StatsLogger(__name__)
+log = logger.stats
+devhint = logger.devhint
+collector = _Collector()
+sentry = _SentryExporter()
diff --git a/dank_mids/types.py b/dank_mids/types.py
index 0f9fc2d0..4dba34c9 100644
--- a/dank_mids/types.py
+++ b/dank_mids/types.py
@@ -1,30 +1,182 @@
-from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, List,
-                    NewType, TypedDict, TypeVar, Union)
+import logging
+from time import time
+from typing import (TYPE_CHECKING, Any, Callable, Coroutine, DefaultDict, Dict,
+                    List, Literal, NewType, Optional, TypedDict, TypeVar,
+                    Union, overload)
 
+import msgspec
 from eth_typing import ChecksumAddress
+from web3.datastructures import AttributeDict
 from web3.types import RPCEndpoint, RPCResponse
 
+from dank_mids import constants, stats
+from dank_mids._exceptions import BadResponse, PayloadTooLarge
+
 if TYPE_CHECKING:
-    from dank_mids.requests import eth_call
+    from dank_mids.requests import Multicall
 
 ChainId = NewType("ChainId", int)
 BlockId = NewType("BlockId", str)
 BatchId = Union[int, str]
-CallsToExec = Dict[BlockId, List["eth_call"]]
+Multicalls = Dict[BlockId, "Multicall"]
 
 eth_callParams = TypedDict("eth_callParams", {"to": ChecksumAddress, "data": str})
 OverrideParams = TypedDict("OverrideParams", {"code": str})
 JsonrpcParams = List[Union[eth_callParams, BlockId, OverrideParams]]
 
-RpcCallJson = TypedDict(
-    "RpcCallJson",
-    {
-        "jsonrpc": str,
-        "id": BatchId,
-        "method": str, # NOTE: must be a valid RPCEndpoint
-        "params": JsonrpcParams,
-    }
-)
 
 # This type alias was introduced in web3 v5.28.0 but we like loose deps here so we recreate instead of import.
 AsyncMiddleware = Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]
+
+list_of_stuff = List[Union[str, None, dict, list]]
+dict_of_stuff = Dict[str, Union[str, None, list_of_stuff, Dict[str, Optional[Any]]]]
+nested_dict_of_stuff = Dict[str, Union[str, None, list_of_stuff, dict_of_stuff]]
+
+class _DictStruct(msgspec.Struct):
+    def __getitem__(self, attr: str) -> Any:
+        return getattr(self, attr)
+    def to_dict(self) -> Dict[str, Any]:
+        data = {}
+        for field in self.__struct_fields__:
+            attr = getattr(self, field)
+            if isinstance(attr, _DictStruct):
+                attr = attr.to_dict()
+            data[field] = AttributeDict(attr) if isinstance(attr, dict) else attr
+        return data
+
+class PartialRequest(_DictStruct):
+    method: str
+    id: Union[str, int]
+    params: Optional[list] = None
+
+    @property
+    def data(self) -> bytes:
+        return msgspec.json.encode(self)
+
+class Request(PartialRequest):
+    # NOTE: While technically part of a request, we can successfully make requests without including the `jsonrpc` field.
+    jsonrpc: Literal["2.0"] = "2.0"
+
+class Error(_DictStruct):
+    code: int
+    message: str
+    data: Optional[Any] = ''
+
+# some devving tools that will go away eventually
+_dict_responses = set()
+_str_responses = set()
+
+# TODO: use the types from snek
+Log = Dict[str, Union[bool, str, None, List[str]]]
+AccessList = List[Dict[str, Union[str, List[str]]]]
+Transaction = Dict[str, Union[str, None, AccessList]]
+
+RETURN_TYPES = {
+    "eth_call": str,
+    "eth_chainId": str,
+    "eth_getCode": str,
+    "eth_getLogs": List[Log],
+    "eth_getBalance": str,
+    "eth_blockNumber": str,  # TODO: see if we can decode this straight to an int
+    "eth_accounts": List[str],
+    "eth_getBlockByNumber": Dict[str, Union[str, List[Union[str, Transaction]]]],
+    "eth_getTransactionCount": str,
+    "eth_getTransactionByHash": Transaction,
+    "eth_getTransactionReceipt": Dict[str, Union[str, None, List[Log]]],
+    "erigon_getHeaderByNumber": Dict[str, Optional[str]],
+}
+
+decoder_logger = logging.getLogger('dank_mids.decoder')
+
+class PartialResponse(_DictStruct):
+    result: msgspec.Raw = None  # type: ignore
+    error: Optional[Error] = None
+
+    @property
+    def exception(self) -> BadResponse:
+        if self.error is None:
+            raise AttributeError(f"{self} did not error.")
+        return PayloadTooLarge(self) if self.payload_too_large else BadResponse(self)
+
+    @property
+    def payload_too_large(self) -> bool:
+        return any(err in self.error.message for err in constants.TOO_MUCH_DATA_ERRS)
+
+    def to_dict(self, method: Optional[str] = None) -> Dict[str, Any]:
+        data = {}
+        for field in self.__struct_fields__:
+            attr = getattr(self, field)
+            if attr is None:
+                continue
+            if field == "result":
+                attr = self.decode_result(method=method, _caller=self)
+            if isinstance(attr, _DictStruct):
+                attr = attr.to_dict()
+            data[field] = AttributeDict(attr) if isinstance(attr, dict) and field != "error" else attr
+        return data
+
+    def decode_result(self, method: Optional[str] = None, _caller = None) -> Any:
+        # NOTE: These must be added to the `RETURN_TYPES` constant above manually
+        if method and (typ := RETURN_TYPES.get(method)):
+            if method in ["eth_call", "eth_blockNumber", "eth_getCode", "eth_getBlockByNumber", "eth_getTransactionReceipt", "eth_getTransactionCount", "eth_getBalance", "eth_chainId", "erigon_getHeaderByNumber"]:
+                return msgspec.json.decode(self.result, type=typ)
+            try:
+                start = time()
+                decoded = msgspec.json.decode(self.result, type=typ)
+                if _caller:
+                    stats.log_duration(f'decoding {type(_caller)} {method}', start)
+                return AttributeDict(decoded) if isinstance(decoded, dict) else decoded
+            except msgspec.ValidationError as e:
+                stats.logger.log_validation_error(method, e)
+
+        # We have some semi-smart logic for providing decoder hints even if method not in `RETURN_TYPES`
+        if method:
+            try:
+                if method in _dict_responses:
+                    decoded = AttributeDict(msgspec.json.decode(self.result, type=nested_dict_of_stuff))
+                    stats.logger.log_types(method, decoded)
+                    return decoded
+                elif method in _str_responses:
+                    # TODO: finish adding methods and get rid of this
+                    stats.logger.devhint(f'Must add `{method}: str` to `RETURN_TYPES`')
+                    return msgspec.json.decode(self.result, type=str)
+            except msgspec.ValidationError as e:
+                stats.logger.log_validation_error(method, e)
+
+        # In this case we can provide no hints, let's let the decoder figure it out
+        decoded = msgspec.json.decode(self.result)
+        if isinstance(decoded, str):
+            if method:
+                _str_responses.add(method)
+            return decoded
+        elif isinstance(decoded, dict):
+            if method:
+                _dict_responses.add(method)
+            return AttributeDict(decoded)
+        raise TypeError(f"type {type(decoded)} is not supported.", decoded)
+
+
+class Response(PartialResponse):
+    id: Optional[Union[str, int]] = None
+    jsonrpc: Literal["2.0"] = "2.0"
+
+class RawResponse:
+    """
+    Wraps a Raw object that we know represents a Response with a `decode` helper method.
+    A `RawResponse` is a properly shaped response for one rpc call, received back from a jsonrpc batch request.
+    They represent either a successful or a failed response, stored as pre-decoded bytes.
+    """
+    def __init__(self, raw: msgspec.Raw) -> None:
+        self._raw = raw
+    @overload
+    def decode(self, partial = True) -> PartialResponse: ...
+    @overload
+    def decode(self, partial = False) -> Response: ...
+    def decode(self, partial: bool = False) -> Union[Response, PartialResponse]:
+        return msgspec.json.decode(self._raw, type=PartialResponse if partial else Response)
+
+JSONRPCBatchRequest = List[Request]
+# NOTE: A PartialResponse result implies a failure response from the rpc.
+JSONRPCBatchResponse = Union[List[RawResponse], PartialResponse]
+# We need this for proper decoding.
+_JSONRPCBatchResponse = Union[List[msgspec.Raw], PartialResponse]
diff --git a/dank_mids/uid.py b/dank_mids/uid.py
index bebd7a99..89f22f3f 100644
--- a/dank_mids/uid.py
+++ b/dank_mids/uid.py
@@ -1,10 +1,13 @@
+import logging
 import threading
 
+logger = logging.getLogger(__name__)
+
 class UIDGenerator:
     def __init__(self) -> None:
         self._value: int = -1
-        self.lock = threading.Lock()
+        self.lock = _AlertingRLock(name='uid')
 
     @property
     def latest(self) -> int:
@@ -18,3 +21,13 @@ def next(self) -> int:
         new: int = self._value + 1
         self._value = new
         return new
+
+class _AlertingRLock(threading._RLock):
+    def __init__(self, name: str) -> None:
+        super().__init__()
+        self.name = name
+    def acquire(self, blocking: bool = True, timeout: int = -1) -> bool:
+        # NOTE: a non-blocking acquire may not specify a timeout, so we probe without one first
+        acquired = super().acquire(blocking=False)
+        if not acquired:
+            logger.warning("wtf?! %s with name %s is locked!", self, self.name)
+            acquired = super().acquire(blocking=blocking, timeout=timeout)
+        return acquired
diff --git a/dank_mids/worker.py b/dank_mids/worker.py
deleted file mode 100644
index 3532554f..00000000
--- a/dank_mids/worker.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import asyncio
-import threading
-from typing import TYPE_CHECKING, Any, Generator, List
-
-import eth_retry
-from eth_typing import ChecksumAddress
-from multicall.multicall import NotSoBrightBatcher
-
-from dank_mids._config import GANACHE_FORK, MAX_JSONRPC_BATCH_SIZE
-from dank_mids.helpers import await_all
-from dank_mids.requests import JSONRPCBatch, Multicall, RPCRequest, _Batch
-from dank_mids.types import CallsToExec
-from dank_mids.uid import UIDGenerator
-
-if TYPE_CHECKING:
-    from dank_mids.controller import DankMiddlewareController
-
-
-class DankWorker:
-    """
-    Runs a second event loop in a subthread which is used to reduce congestion on the main event loop.
-    This allows dank_mids to better communicate with your node while you abuse it with heavy loads.
-    """
-    def __init__(self, controller: "DankMiddlewareController") -> None:
-        self.controller = controller
-        self.target: ChecksumAddress = self.controller.multicall2
-        self.batcher = NotSoBrightBatcher()
-        self.multicall_uid: UIDGenerator = UIDGenerator()
-        self.request_uid: UIDGenerator = UIDGenerator()
-        self.jsonrpc_batch_uid: UIDGenerator = UIDGenerator()
-        self.state_override_not_supported: bool = GANACHE_FORK or self.controller.chain_id == 100 # Gnosis Chain does not support state override.
diff --git a/dank_mids/worker.py b/dank_mids/worker.py
deleted file mode 100644
index 3532554f..00000000
--- a/dank_mids/worker.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import asyncio
-import threading
-from typing import TYPE_CHECKING, Any, Generator, List
-
-import eth_retry
-from eth_typing import ChecksumAddress
-from multicall.multicall import NotSoBrightBatcher
-
-from dank_mids._config import GANACHE_FORK, MAX_JSONRPC_BATCH_SIZE
-from dank_mids.helpers import await_all
-from dank_mids.requests import JSONRPCBatch, Multicall, RPCRequest, _Batch
-from dank_mids.types import CallsToExec
-from dank_mids.uid import UIDGenerator
-
-if TYPE_CHECKING:
-    from dank_mids.controller import DankMiddlewareController
-
-
-class DankWorker:
-    """
-    Runs a second event loop in a subthread which is used to reduce congestion on the main event loop.
-    This allows dank_mids to better communicate with your node while you abuse it with heavy loads.
-    """
-    def __init__(self, controller: "DankMiddlewareController") -> None:
-        self.controller = controller
-        self.target: ChecksumAddress = self.controller.multicall2
-        self.batcher = NotSoBrightBatcher()
-        self.multicall_uid: UIDGenerator = UIDGenerator()
-        self.request_uid: UIDGenerator = UIDGenerator()
-        self.jsonrpc_batch_uid: UIDGenerator = UIDGenerator()
-        self.state_override_not_supported: bool = GANACHE_FORK or self.controller.chain_id == 100  # Gnosis Chain does not support state override.
-        self.event_loop = asyncio.new_event_loop()
-        self.worker_thread = threading.Thread(target=self.start)
-        self.worker_thread.start()
-
-    def start(self) -> None:
-        """ Runs in worker thread. """
-        asyncio.set_event_loop(self.event_loop)
-        self.event_loop.run_until_complete(self.loop())
-
-    async def loop(self) -> None:
-        """ Exits loop when main thread dies, killing worker thread. Runs in worker thread. """
-        while threading.main_thread().is_alive():
-            await asyncio.sleep(5)
-
-    @eth_retry.auto_retry
-    async def __call__(self, *request_args: Any) -> Any:
-        return await self.controller.w3.eth.call(*request_args)  # type: ignore
-
-    @property
-    def endpoint(self) -> str:
-        return self.controller.w3.provider.endpoint_uri  # type: ignore
-
-    async def execute_batch(self, calls_to_exec: CallsToExec, rpc_calls: List[RPCRequest]) -> None:
-        """ Runs in main thread. """
-        asyncio.run_coroutine_threadsafe(self._execute_batch(calls_to_exec, rpc_calls), self.event_loop).result()
-
-    async def _execute_batch(self, calls_to_exec: CallsToExec, rpc_calls: List[RPCRequest]) -> None:
-        """ Runs in worker thread. """
-        await DankBatch(self, calls_to_exec, rpc_calls)
-
-
-class DankBatch:
-    """ A batch of jsonrpc batches. """
-    def __init__(self, worker: DankWorker, eth_calls: CallsToExec, rpc_calls: List[RPCRequest]):
-        self.worker = worker
-        self.eth_calls = eth_calls
-        self.rpc_calls = rpc_calls
-
-    def __await__(self) -> Generator[Any, None, Any]:
-        return await_all(self.coroutines).__await__()
-
-    @property
-    def batcher(self) -> NotSoBrightBatcher:
-        return self.worker.batcher
-
-    @property
-    def coroutines(self) -> Generator["_Batch", None, None]:
-        multicalls_to_batch: List["Multicall"] = []
-        for *full_batches, remaining_calls in (self.batcher.batch_calls(calls, self.batcher.step) for calls in self.eth_calls.values()):
-            yield from (Multicall(self.worker, batch) for batch in full_batches)
-            multicalls_to_batch.append(Multicall(self.worker, remaining_calls))
-        # Combine multicalls into one or more jsonrpc batches
-        *full_batches, working_batch = self.batch_multicalls(multicalls_to_batch)
-
-        # Yield full batches then prepare the rest
-        yield from full_batches
-        rpc_calls_to_batch = self.rpc_calls[:]
-        while rpc_calls_to_batch:
-            if len(working_batch) >= MAX_JSONRPC_BATCH_SIZE:
-                yield working_batch
-                working_batch = JSONRPCBatch(self.worker)
-            working_batch.append(rpc_calls_to_batch.pop())
-        if working_batch:
-            if working_batch.is_single_multicall:
-                yield working_batch[0]  # type: ignore [misc]
-            else:
-                yield working_batch
-
-    def batch_multicalls(self, multicalls: List["Multicall"]) -> Generator["JSONRPCBatch", None, None]:
-        """ Used to collect multicalls into batches without overwhelming the node with oversized calls. """
-        multicalls.sort(key=lambda x: len(x), reverse=True)
-        eth_calls_in_batch = 0
-        working_batch = JSONRPCBatch(self.worker)
-        for mcall in multicalls:
-            assert isinstance(mcall, Multicall)
-            # This would be too many eth_calls for a single multicall, lets start a new jsonrpc batch.
-            if eth_calls_in_batch + len(mcall) > self.batcher.step:
-                if working_batch:
-                    yield working_batch
-                working_batch = JSONRPCBatch(self.worker, [mcall])
-                eth_calls_in_batch = len(mcall)
-            # This can be added to the current jsonrpc batch
-            else:
-                working_batch.append(mcall)
-                eth_calls_in_batch += len(mcall)
-            if len(working_batch) >= MAX_JSONRPC_BATCH_SIZE:
-                # There are more than `MAX_JSONRPC_BATCH_SIZE` rpc calls packed into this batch, let's start a new one
-                yield working_batch
-                working_batch = JSONRPCBatch(self.worker)
-                eth_calls_in_batch = 0
-        yield working_batch
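For readers unfamiliar with the pattern the deleted worker used: it ran a second event loop in a subthread and fed it coroutines via asyncio.run_coroutine_threadsafe. A minimal standalone sketch of that pattern (invented names, not dank_mids code):

import asyncio
import threading

# Start a dedicated event loop in a daemon thread, as the old DankWorker did.
worker_loop = asyncio.new_event_loop()
threading.Thread(target=worker_loop.run_forever, daemon=True).start()

async def do_work(n: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for node I/O
    return n * 2

# Submit a coroutine from the main thread to the worker loop, then block on its result,
# which is what DankWorker.execute_batch did with each DankBatch.
future = asyncio.run_coroutine_threadsafe(do_work(21), worker_loop)
print(future.result())  # -> 42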
diff --git a/requirements.txt b/requirements.txt
index b34835a9..86966bc6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
-bobs_lazy_logging>=0.0.2
 eth_retry>=0.1.15,<0.2
+ez-a-sync>=0.6.3
+msgspec
 multicall>=0.6.2
+typed-envs>=0.0.2
 # web3 5.29 thru 5.31.2 are unable to use async provider in subthreads
-web3>=5.27,!=5.29.*,!=5.30.*,!=5.31.1,!=5.31.2
+web3>=5.27,!=5.29.*,!=5.30.*,!=5.31.1,!=5.31.2
\ No newline at end of file
diff --git a/tests/test_brownie_patch.py b/tests/test_brownie_patch.py
index d49b057a..8df727b4 100644
--- a/tests/test_brownie_patch.py
+++ b/tests/test_brownie_patch.py
@@ -1,10 +1,12 @@
+# sourcery skip: no-loop-in-tests
+import asyncio
 from brownie import Contract, web3
-from dank_mids.brownie_patch import patch_contract
-from dank_mids.brownie_patch.call import _patch_call
-from dank_mids import setup_dank_w3_from_sync
 from multicall.utils import await_awaitable
 
+from dank_mids import setup_dank_w3_from_sync
+from dank_mids.brownie_patch import patch_contract
+from dank_mids.brownie_patch.call import _patch_call
 from tests.fixtures import dank_w3
 
@@ -15,6 +17,14 @@ def test_patch_call():
     assert hasattr(weth.totalSupply, 'coroutine')
     assert await_awaitable(weth.totalSupply.coroutine(block_identifier=13_000_000)) == 6620041514474872981393155
 
+def test_gather():
+    # must use from_explorer for gh testing workflow
+    weth = Contract.from_explorer('0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2')
+    _patch_call(weth.totalSupply, dank_w3)
+    assert hasattr(weth.totalSupply, 'coroutine')
+    for result in await_awaitable(asyncio.gather(*[weth.totalSupply.coroutine(block_identifier=13_000_000) for _ in range(10_000)])):
+        assert result == 6620041514474872981393155
+
 def test_patch_contract():
     # ContractCall
     # must use from_explorer for gh testing workflow
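The new test_gather fires 10,000 identical coroutines through one patched contract method to exercise JSON-RPC batching end to end. The same idea against the public helpers, roughly (a hedged sketch assuming a configured Brownie network, not a test from this repo):

import asyncio

from brownie import web3
from dank_mids import setup_dank_w3_from_sync

dank_w3 = setup_dank_w3_from_sync(web3)

async def main() -> None:
    # dank_mids transparently collects these 1,000 concurrent calls into
    # a handful of JSON-RPC batch requests instead of 1,000 round trips.
    blocks = await asyncio.gather(*[dank_w3.eth.get_block_number() for _ in range(1_000)])
    print(f"made {len(blocks)} calls, saw {len(set(blocks))} distinct block number(s)")

asyncio.run(main())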
diff --git a/tests/test_dank_mids.py b/tests/test_dank_mids.py
index 22c350fe..0046baab 100644
--- a/tests/test_dank_mids.py
+++ b/tests/test_dank_mids.py
@@ -18,14 +18,11 @@ def _get_controller():
     return instances[chain.id][0]
 
-def _get_worker():
-    return _get_controller().worker
-
 def test_dank_middleware():
     await_awaitable(gather(BIG_WORK))
     cid = _get_controller().call_uid.latest
-    mid = _get_worker().multicall_uid.latest
-    rid = _get_worker().request_uid.latest
+    mid = _get_controller().multicall_uid.latest
+    rid = _get_controller().request_uid.latest
     assert cid, "The DankMiddlewareController did not process any calls."
     assert mid, "The DankMiddlewareController did not process any batches."
     assert rid, "The DankMiddlewareController did not process any requests."
@@ -49,10 +46,10 @@ def test_next_cid():
     assert _get_controller().call_uid.next + 1 == _get_controller().call_uid.next
 
 def test_next_mid():
-    assert _get_worker().request_uid.next + 1 == _get_worker().request_uid.next
+    assert _get_controller().request_uid.next + 1 == _get_controller().request_uid.next
 
 def test_next_bid():
-    assert _get_worker().multicall_uid.next + 1 == _get_worker().multicall_uid.next
+    assert _get_controller().multicall_uid.next + 1 == _get_controller().multicall_uid.next
 
 def test_other_methods():
     work = [dank_w3.eth.get_block_number() for i in range(50)]
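At first glance `x.next + 1 == x.next` looks like it could never hold; it passes because `next` is a side-effecting property that mints a fresh uid on every access. A standalone sketch of the contract these tests rely on (simplified, not the actual dank_mids class):

import threading

class UIDGenerator:
    def __init__(self) -> None:
        self._value = -1
        self.lock = threading.RLock()

    @property
    def next(self) -> int:
        # Each access increments under the lock, so consecutive reads
        # yield strictly increasing ids, even across threads.
        with self.lock:
            self._value += 1
            return self._value

gen = UIDGenerator()
assert gen.next + 1 == gen.next  # 0 + 1 == 1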