fix(mypy): fix type errs (#184)
BobTheBuidler authored Apr 24, 2024
1 parent 6ff6e4a commit 1d0d3e6
Showing 9 changed files with 53 additions and 44 deletions.
2 changes: 1 addition & 1 deletion dank_mids/_demo_mode.py
@@ -9,4 +9,4 @@ class DummyLogger:
     def info(self, *args: Any, **kwargs: Any) -> None:
         ...
 
-demo_logger = logging.getLogger("dank_mids.demo") if ENVIRONMENT_VARIABLES.DEMO_MODE else DummyLogger()
+demo_logger = logging.getLogger("dank_mids.demo") if ENVIRONMENT_VARIABLES.DEMO_MODE else DummyLogger() # type: ignore [attr-defined]
3 changes: 2 additions & 1 deletion dank_mids/_exceptions.py
@@ -26,9 +26,10 @@ class PayloadTooLarge(BadResponse):
     pass
 
 class ExceedsMaxBatchSize(BadResponse):
+    """A special exception that occurs when you post a batch which exceeds the maximum batch size for the rpc."""
     @property
     def limit(self) -> int:
-        return int(re.search(r'batch limit (\d+) exceeded', self.response.error.message).group(1))
+        return int(re.search(r'batch limit (\d+) exceeded', self.response.error.message).group(1)) # type: ignore [union-attr]
 
 class DankMidsClientResponseError(ClientResponseError):
     """A wrapper around the standard aiohttp ClientResponseError that attaches the request that generated the error."""
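
The `[union-attr]` ignore is needed because `re.search` returns `Optional[re.Match[str]]`, so mypy flags `.group` on a possibly-None result. A sketch of the explicit narrowing the ignore sidesteps (the helper below is hypothetical):

import re
from typing import Optional

def batch_limit(message: str) -> int:
    match: Optional[re.Match[str]] = re.search(r"batch limit (\d+) exceeded", message)
    if match is None:
        # The commit instead asserts, via the ignore, that the pattern
        # always matches for this error type.
        raise ValueError(f"unexpected error message: {message!r}")
    return int(match.group(1))
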
28 changes: 14 additions & 14 deletions dank_mids/_requests.py
@@ -85,7 +85,7 @@ def response(self) -> _Response:
     async def get_response(self) -> Optional[_Response]:
         pass
 
-    async def _debug_daemon(self) -> NoReturn:
+    async def _debug_daemon(self) -> None:
         while not self._done.is_set():
             await asyncio.sleep(60)
             if not self._done.is_set():
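
The annotation change above is a real fix rather than an ignore: `NoReturn` is reserved for functions that can never return, and `_debug_daemon` exits its loop once `self._done` is set. A minimal illustration with hypothetical functions:

import asyncio
from typing import NoReturn

def spin_forever() -> NoReturn:  # OK: control can never leave the loop
    while True:
        pass

async def run_until_done(done: asyncio.Event) -> None:  # the loop can end
    while not done.is_set():
        await asyncio.sleep(60)
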
@@ -233,7 +233,7 @@ async def spoof_response(self, data: Union[RawResponse, bytes, Exception]) -> None:
         if data.response.error.message.lower() in ['invalid request', 'parse error']:
             if self.controller._time_of_request_type_change == 0:
                 self.controller.request_type = Request
-                self.controller._time_of_request_type_change = time.time()
+                self.controller._time_of_request_type_change = int(time.time())
             if time.time() - self.controller._time_of_request_type_change <= 600:
                 logger.debug("your node says the partial request was invalid but its okay, we can use the full jsonrpc spec instead")
                 self._response = await self.create_duplicate()
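
Another real fix: `time.time()` returns `float`, so assigning it to an attribute annotated as `int` fails mypy's assignment check. Truncating with `int(...)` drops only sub-second precision, which the 600-second window above does not need. Sketch (hypothetical module-level variable):

import time

_time_of_request_type_change: int = 0  # annotated as int, as on the controller

# _time_of_request_type_change = time.time()   # mypy: [assignment] float -> int
_time_of_request_type_change = int(time.time())  # accepted; whole seconds suffice
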
@@ -266,7 +266,7 @@ def semaphore(self) -> a_sync.Semaphore:
             semaphore = semaphore[self.params[1]]
         return semaphore
 
-    async def create_duplicate(self) -> Self: # Not actually self, but for typing purposes it is.
+    async def create_duplicate(self) -> RPCResponse: # Not actually self, but for typing purposes it is.
         # We need to make room since the stalled call is still holding the semaphore
         self.semaphore.release()
         # We need to check the semaphore again to ensure we have the right context manager, soon but not right away.
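
On the `Self` change: a `Self` return type promises an instance of the enclosing class, and mypy verifies the body delivers one. Since this method resolves to the duplicate call's RPC response rather than the request object itself, `RPCResponse` is the honest annotation, as the old inline comment already hinted. Illustrative classes (hypothetical, not from the repo):

from typing import Self  # Python 3.11+; typing_extensions.Self on older versions

class Response:
    pass

class Request:
    def copy(self) -> Self:        # OK: actually returns an instance of this class
        return type(self)()

    def run(self) -> "Response":   # returns something else, so it must say so;
        return Response()          # annotating this -> Self would be a type error
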
@@ -566,19 +566,19 @@ async def spoof_response(self, data: Union[RawResponse, Exception]) -> None:
 
     async def decode(self, data: PartialResponse) -> List[Tuple[bool, bytes]]:
         start = time.time()
-        if ENVS.OPERATION_MODE.infura:
+        if ENVS.OPERATION_MODE.infura: # type: ignore [attr-defined]
             retval = mcall_decode(data)
         else:
             try: # NOTE: Quickly check for length without counting each item with `len`.
-                if not ENVS.OPERATION_MODE.application:
+                if not ENVS.OPERATION_MODE.application: # type: ignore [attr-defined]
                     self[100]
-                retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data)
+                retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data) # type: ignore [attr-defined]
             except IndexError:
                 retval = mcall_decode(data)
             except BrokenProcessPool:
                 # TODO: Move this somewhere else
-                logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data)
-                ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers)
+                logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data) # type: ignore [attr-defined]
+                ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers) # type: ignore [attr-defined]
                 retval = mcall_decode(data)
         stats.log_duration(f"multicall decoding for {len(self)} calls", start)
         # Raise any Exceptions that may have come out of the process pool.
@@ -672,15 +672,15 @@ def total_calls(self) -> int:
     @property
     def is_full(self) -> bool:
         with self._lock:
-            return self.total_calls >= self.controller.batcher.step or len(self) >= ENVS.MAX_JSONRPC_BATCH_SIZE
+            return self.total_calls >= self.controller.batcher.step or len(self) >= ENVS.MAX_JSONRPC_BATCH_SIZE # type: ignore [attr-defined]
 
     async def get_response(self) -> None:
         if self._started:
             logger.warning(f"{self} exiting early. This shouldn't really happen bro")
             return
         self._started = True
         rid = self.controller.request_uid.next
-        if ENVS.DEMO_MODE:
+        if ENVS.DEMO_MODE: # type: ignore [attr-defined]
             # When demo mode is disabled, we can save some CPU time by skipping this sum
             demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} ({sum(len(batch) for batch in self.calls)} calls) starting') # type: ignore
         try:
@@ -732,13 +732,13 @@ async def post(self) -> List[RawResponse]:
             elif 'broken pipe' in str(e).lower():
                 logger.warning("This is what broke the pipe: %s", self.method_counts)
             logger.debug("caught %s for %s, reraising", e, self)
-            if ENVS.DEBUG:
+            if ENVS.DEBUG: # type: ignore [attr-defined]
                 _debugging.failures.record(self.controller.chain_id, e, type(self).__name__, self.uid, len(self), self.data)
             raise e
         except Exception as e:
             if 'broken pipe' in str(e).lower():
                 logger.warning("This is what broke the pipe: %s", self.method_counts)
-            if ENVS.DEBUG:
+            if ENVS.DEBUG: # type: ignore [attr-defined]
                 _debugging.failures.record(self.controller.chain_id, e, type(self).__name__, self.uid, len(self), self.data)
             raise e
         # NOTE: A successful response will be a list of `RawResponse` objects.
@@ -846,7 +846,7 @@ def _log_exception(e: Exception) -> bool:
 
     stre = str(e).lower()
    if any(err in stre for err in dont_need_to_see_errs):
-        return ENVS.DEBUG
+        return ENVS.DEBUG # type: ignore [attr-defined]
     logger.warning("The following exception is being logged for informational purposes and does not indicate failure:")
     logger.warning(e, exc_info=True)
-    return ENVS.DEBUG
+    return ENVS.DEBUG # type: ignore [attr-defined]
4 changes: 2 additions & 2 deletions dank_mids/brownie_patch/_method.py
@@ -97,9 +97,9 @@ async def coroutine( # type: ignore [empty-body]
         """
         if override:
             raise ValueError("Cannot use state override with `coroutine`.")
-        async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]:
+        async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
             data = await self._encode_input(self, self._len_inputs, self._prep_request_data, *args)
-        async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]:
+        async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
             output = await self._web3.eth.call({"to": self._address, "data": data}, block_identifier)
         try:
             decoded = await self._decode_output(self, output)
18 changes: 9 additions & 9 deletions dank_mids/brownie_patch/call.py
@@ -27,8 +27,8 @@
 from dank_mids.exceptions import Revert
 
 logger = logging.getLogger(__name__)
-encode = lambda self, *args: ENVS.BROWNIE_ENCODER_PROCESSES.run(__encode_input, self.abi, self.signature, *args)
-decode = lambda self, data: ENVS.BROWNIE_DECODER_PROCESSES.run(__decode_output, data, self.abi)
+encode = lambda self, *args: ENVS.BROWNIE_ENCODER_PROCESSES.run(__encode_input, self.abi, self.signature, *args) # type: ignore [attr-defined]
+decode = lambda self, data: ENVS.BROWNIE_DECODER_PROCESSES.run(__decode_output, data, self.abi) # type: ignore [attr-defined]
 
 def _patch_call(call: ContractCall, w3: Web3) -> None:
     call._skip_decoder_proc_pool = call._address in _skip_proc_pool
@@ -37,7 +37,7 @@ def _patch_call(call: ContractCall, w3: Web3) -> None:
 
 @functools.lru_cache
 def _get_coroutine_fn(w3: Web3, len_inputs: int):
-    if ENVS.OPERATION_MODE.application or len_inputs:
+    if ENVS.OPERATION_MODE.application or len_inputs: # type: ignore [attr-defined]
         get_request_data = encode
     else:
         get_request_data = _request_data_no_args
@@ -51,9 +51,9 @@ async def coroutine(
     ) -> Any:
         if override:
             raise ValueError("Cannot use state override with `coroutine`.")
-        async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]:
+        async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
             data = await encode_input(self, len_inputs, get_request_data, *args)
-        async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]:
+        async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
             output = await w3.eth.call({"to": self._address, "data": data}, block_identifier)
         try:
             decoded = await decode_output(self, output)
@@ -80,9 +80,9 @@ async def encode_input(call: ContractCall, len_inputs, get_request_data, *args)
         data = __encode_input(call.abi, call.signature, *args)
     # TODO: move this somewhere else
     except BrokenProcessPool:
-        logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_ENCODER_PROCESSES, data, call.abi)
+        logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_ENCODER_PROCESSES, data, call.abi) # type: ignore [attr-defined]
         # Let's fix that right up
-        ENVS.BROWNIE_ENCODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_ENCODER_PROCESSES._max_workers)
+        ENVS.BROWNIE_ENCODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_ENCODER_PROCESSES._max_workers) # type: ignore [attr-defined]
         data = __encode_input(call.abi, call.signature, *args) if len_inputs else call.signature
     except PicklingError: # But if that fails, don't worry. I got you.
         data = __encode_input(call.abi, call.signature, *args) if len_inputs else call.signature
@@ -104,8 +104,8 @@ async def decode_output(call: ContractCall, data: bytes) -> Any:
     # TODO: move this somewhere else
     except BrokenProcessPool:
         # Let's fix that right up
-        logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_DECODER_PROCESSES, data, call.abi)
-        ENVS.BROWNIE_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_DECODER_PROCESSES._max_workers)
+        logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_DECODER_PROCESSES, data, call.abi) # type: ignore [attr-defined]
+        ENVS.BROWNIE_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_DECODER_PROCESSES._max_workers) # type: ignore [attr-defined]
         decoded = __decode_output(data, call.abi)
         # We have to do it like this so we don't break the process pool.
         if isinstance(decoded, Exception):
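
A side note on the recovery pattern in both `except BrokenProcessPool` blocks: a pool that raises `BrokenProcessPool` can never accept work again, so the code installs a fresh executor for future calls and computes the current result in-process. A rough stdlib sketch of the same idea (the helper is hypothetical):

from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool

pool = ProcessPoolExecutor(max_workers=4)

def run_with_recovery(fn, *args):
    global pool
    try:
        return pool.submit(fn, *args).result()
    except BrokenProcessPool:
        # The broken pool is permanently dead: swap in a fresh executor for
        # future callers and run this one job in-process, as the patch does.
        pool = ProcessPoolExecutor(max_workers=4)
        return fn(*args)
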
10 changes: 5 additions & 5 deletions dank_mids/helpers/_session.py
@@ -66,10 +66,10 @@ def __new__(cls, value, phrase, description=''):
 HTTPStatusExtended = IntEnum('HTTPStatusExtended', [(i.name, i.value) for i in chain(http.HTTPStatus, _HTTPStatusExtension)])
 
 RETRY_FOR_CODES = {
-    HTTPStatusExtended.BAD_GATEWAY,
-    HTTPStatusExtended.WEB_SERVER_IS_RETURNING_AN_UNKNOWN_ERROR,
-    HTTPStatusExtended.CLOUDFLARE_CONNECTION_TIMEOUT,
-    HTTPStatusExtended.CLOUDFLARE_TIMEOUT,
+    HTTPStatusExtended.BAD_GATEWAY, # type: ignore [attr-defined]
+    HTTPStatusExtended.WEB_SERVER_IS_RETURNING_AN_UNKNOWN_ERROR, # type: ignore [attr-defined]
+    HTTPStatusExtended.CLOUDFLARE_CONNECTION_TIMEOUT, # type: ignore [attr-defined]
+    HTTPStatusExtended.CLOUDFLARE_TIMEOUT, # type: ignore [attr-defined]
 }
 
 limiter = AsyncLimiter(5, 0.1) # 50 requests/second
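
These four ignores exist because `HTTPStatusExtended` is built with the enum functional API from a sequence computed at runtime, so mypy cannot determine its members. A minimal illustration (hypothetical enums):

from enum import IntEnum

# Members supplied as runtime data are invisible to static analysis...
_codes = [("BAD_GATEWAY", 502)]
Dynamic = IntEnum("Dynamic", _codes)
code = Dynamic.BAD_GATEWAY  # fine at runtime; mypy reports [attr-defined]

# ...while statically declared members type-check cleanly.
class Declared(IntEnum):
    BAD_GATEWAY = 502
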
@@ -137,7 +137,7 @@ async def _get_session_for_thread(thread_ident: int) -> ClientSession:
     return ClientSession(
         connector = TCPConnector(limit=32),
         headers = {'content-type': 'application/json'},
-        timeout = ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT), # type: ignore [arg-type]
+        timeout = ClientTimeout(ENVIRONMENT_VARIABLES.AIOHTTP_TIMEOUT), # type: ignore [arg-type, attr-defined]
         raise_for_status = True,
     )
 
14 changes: 7 additions & 7 deletions dank_mids/stats.py
@@ -73,7 +73,7 @@ def log_event_loop_stats(self, *, level: _LogLevel = STATS) -> None:
         self._log_fn_result(level, _Writer.event_loop)
 
     def log_subprocess_stats(self, *, level: _LogLevel = STATS) -> None:
-        for pool in {ENVS.BROWNIE_ENCODER_PROCESSES, ENVS.BROWNIE_DECODER_PROCESSES, ENVS.MULTICALL_DECODER_PROCESSES}:
+        for pool in {ENVS.BROWNIE_ENCODER_PROCESSES, ENVS.BROWNIE_DECODER_PROCESSES, ENVS.MULTICALL_DECODER_PROCESSES}: # type: ignore [attr-defined]
             self._log_fn_result(level, _Writer.queue, pool)
 
     # Internal helpers
@@ -97,7 +97,7 @@ def _log_fn_result(self, level: _LogLevel, callable: Callable[[], str], *callable_args) -> None:
     # Daemon
 
     def _ensure_daemon(self) -> None:
-        if (ENVS.COLLECT_STATS or self.enabled) and self._daemon is None:
+        if (ENVS.COLLECT_STATS or self.enabled) and self._daemon is None: # type: ignore [attr-defined]
             self._daemon = asyncio.create_task(self._stats_daemon())
         elif self._daemon.done():
             raise self._daemon.exception()
@@ -159,19 +159,19 @@ def avg_loop_time(self) -> float:
         return sum(collector.event_loop_times) / len(collector.event_loop_times)
     @property
     def count_active_brownie_calls(self) -> int:
-        return ENVS.BROWNIE_CALL_SEMAPHORE.default_value - ENVS.BROWNIE_CALL_SEMAPHORE.semaphore._value
+        return ENVS.BROWNIE_CALL_SEMAPHORE.default_value - ENVS.BROWNIE_CALL_SEMAPHORE.semaphore._value # type: ignore [attr-defined]
     @property
     def count_queued_brownie_calls(self) -> int:
-        return len(ENVS.BROWNIE_CALL_SEMAPHORE.semaphore._waiters)
+        return len(ENVS.BROWNIE_CALL_SEMAPHORE.semaphore._waiters) # type: ignore [attr-defined]
     @property
    def encoder_queue_len(self) -> int:
-        return ENVS.BROWNIE_ENCODER_PROCESSES._queue_count
+        return ENVS.BROWNIE_ENCODER_PROCESSES._queue_count # type: ignore [attr-defined]
     @property
     def decoder_queue_len(self) -> int:
-        return ENVS.BROWNIE_DECODER_PROCESSES._queue_count
+        return ENVS.BROWNIE_DECODER_PROCESSES._queue_count # type: ignore [attr-defined]
     @property
     def mcall_decoder_queue_len(self) -> int:
-        return ENVS.MULTICALL_DECODER_PROCESSES._queue_count
+        return ENVS.MULTICALL_DECODER_PROCESSES._queue_count # type: ignore [attr-defined]
 
 
 class _Writer:
13 changes: 9 additions & 4 deletions examples/dank_brownie_example.py
@@ -10,7 +10,7 @@
 from typing import List
 
 import dank_mids
-
+from web3.types import Timestamp
 
 # For the purpose of this example, we will define the Uniswap pools we want to get data from
 # and the blocks at which we wish to fetch data.
@@ -60,6 +60,11 @@ async def get_tokens_for_pool(pool: dank_mids.Contract):
 # To batch other rpc calls, use the `dank_mids.eth` object like you would brownie's `web3.eth` object.
 # This object wraps the connected brownie Web3 instance and injects the dank middleware for batching
 
-async def get_timestamp_at_block(block: int):
-    block = await dank_mids.eth.get_block(block)
-    return block.timestamp
+async def get_timestamp_at_block(block: int) -> Timestamp:
+    data = await dank_mids.eth.get_block(block)
+    # dank mids will turn all dict responses into attr dicts for easier key lookup
+    # the syntax below will work but won't type check correctly
+    not_typed = data.timestamp # type: ignore [attr-defined]
+    # the syntax below is more annoying to type out everywhere but will work with type checkers
+    typed = data['timestamp']
+    return typed
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -1,3 +1,6 @@
 
 [tool.mypy]
-exclude = ["build/","cache/","env/","tests/"]
+exclude = ["build/","cache/","env/","tests/"]
+ignore_missing_imports = true
+check_untyped_defs = true
+disable_error_code = ["return"]
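
For readers unfamiliar with these mypy settings: `ignore_missing_imports` suppresses errors for third-party packages that ship no type stubs, `check_untyped_defs` makes mypy check the bodies of unannotated functions it would otherwise skip, and `disable_error_code = ["return"]` turns off the missing-return-statement error code. A tiny illustration of what `check_untyped_defs` newly catches (hypothetical function):

def configure():             # unannotated, so skipped unless check_untyped_defs = true
    retries: int = "three"   # now reported: incompatible types in assignment
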
