Skip to content

Commit

Permalink
chore: run black and isort on brownie_patch module (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Nov 3, 2024
1 parent 8e77756 commit 2855502
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 78 deletions.
9 changes: 7 additions & 2 deletions dank_mids/brownie_patch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@

from contextlib import suppress

from web3.eth import AsyncEth

from dank_mids.helpers import setup_dank_w3_from_sync
from dank_mids.brownie_patch.types import DankContractCall, DankContractMethod, DankContractTx, DankOverloadedMethod
from dank_mids.brownie_patch.types import (
DankContractCall,
DankContractMethod,
DankContractTx,
DankOverloadedMethod,
)

__all__ = ["DankContractCall", "DankContractMethod", "DankContractTx", "DankOverloadedMethod"]

Expand All @@ -27,6 +31,7 @@
# If using dank_mids wih brownie, and brownie is connected when this file executes, you will get a 'dank_w3' async web3 instance with Dank Middleware here.
with suppress(ImportError):
from brownie import network, web3

if network.is_connected():
from dank_mids.brownie_patch.contract import Contract, patch_contract

Expand Down
7 changes: 3 additions & 4 deletions dank_mids/brownie_patch/_abi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import functools
from typing import Any

Expand All @@ -9,7 +8,7 @@
class FunctionABI:
"""
A singleton class to hold function ABI information.
This class uses the lru_cache decorator to ensure only one instance is created
for each unique set of ABI parameters, optimizing memory usage and performance.
"""
Expand All @@ -23,7 +22,7 @@ def __init__(self, **abi: Any):
Args:
**abi: Keyword arguments representing the ABI of the function.
"""

self.abi = abi
"""
The complete ABI (Application Binary Interface) of the function.
Expand All @@ -40,4 +39,4 @@ def __init__(self, **abi: Any):
"""
The function selector (4-byte signature) of the function.
This is used in Ethereum transactions to identify which function to call.
"""
"""
51 changes: 33 additions & 18 deletions dank_mids/brownie_patch/_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from decimal import Decimal
from typing import Any, Awaitable, Callable, Dict, Generic, Iterable, List, Optional, TypeVar

from brownie.typing import AccountsType
from brownie.convert.datatypes import EthAddress
from brownie.typing import AccountsType
from eth_abi.exceptions import InsufficientDataBytes
from hexbytes.main import BytesLike

Expand All @@ -14,6 +14,7 @@

_EVMType = TypeVar("_EVMType")


class _DankMethodMixin(Generic[_EVMType]):
"""
A mixin class that is used internally to enhance Brownie's contract methods
Expand All @@ -31,15 +32,15 @@ class _DankMethodMixin(Generic[_EVMType]):
def __await__(self):
"""
Allow the contract method to be awaited.
This method enables using 'await' on the contract method, which will call
the method without arguments at the latest block and return the result.
"""
return self.coroutine().__await__()

async def map(
self,
args: Iterable[Any],
self,
args: Iterable[Any],
block_identifier: Optional[int] = None,
decimals: Optional[int] = None,
) -> List[_EVMType]:
Expand All @@ -62,7 +63,7 @@ async def map(
def abi(self) -> dict:
"""
The ABI of the contract function.
This property provides access to the complete ABI dictionary of the function.
"""
return self._abi.abi
Expand All @@ -71,16 +72,16 @@ def abi(self) -> dict:
def signature(self) -> str:
"""
The function signature.
This property returns the unique signature of the contract function,
which is used to identify the function in transactions.
"""
return self._abi.signature

async def coroutine( # type: ignore [empty-body]
self,
*args: Any,
block_identifier: Optional[int] = None,
self,
*args: Any,
block_identifier: Optional[int] = None,
decimals: Optional[int] = None,
override: Optional[Dict[str, str]] = None,
) -> _EVMType:
Expand All @@ -99,30 +100,40 @@ async def coroutine( # type: ignore [empty-body]
The result of the contract method call.
"""
raise NotImplementedError

@property
def _input_sig(self) -> str:
return self._abi.input_sig

@functools.cached_property
def _len_inputs(self) -> int:
return len(self.abi['inputs'])
return len(self.abi["inputs"])

@functools.cached_property
def _skip_decoder_proc_pool(self) -> bool:
from dank_mids.brownie_patch.call import _skip_proc_pool

return self._address in _skip_proc_pool

@functools.cached_property
def _web3(cls) -> DankWeb3:
from dank_mids import web3

return web3

@functools.cached_property
def _prep_request_data(self) -> Callable[..., Awaitable[BytesLike]]:
from dank_mids.brownie_patch import call

if ENVS.OPERATION_MODE.application or self._len_inputs: # type: ignore [attr-defined]
return call.encode
else:
return call._request_data_no_args


class _DankMethod(_DankMethodMixin):
__slots__ = "_address", "_abi", "_name", "_owner", "natspec", "_encode_input", "_decode_output"

def __init__(
self,
address: str,
Expand All @@ -142,26 +153,28 @@ def __init__(

self.natspec = natspec or {}
"""The NatSpec documentation for the function."""

# TODO: refactor this
from dank_mids.brownie_patch import call

self._encode_input = call.encode_input
self._decode_output = call.decode_output

async def coroutine( # type: ignore [empty-body]
self,
*args: Any,
block_identifier: Optional[int] = None,
self,
*args: Any,
block_identifier: Optional[int] = None,
decimals: Optional[int] = None,
override: Optional[Dict[str, str]] = None,
) -> _EVMType:
"""
Asynchronously call the contract method via dank mids and await the result.
Arguments:
- *args: The arguments for the contract method.
- block_identifier (optional): The block at which the chain will be read. If not provided, will read the chain at latest block.
- decimals (optional): if provided, the output will be `result / 10 ** decimals`
Returns:
- Whatever the node sends back as the output for this contract method.
"""
Expand All @@ -170,9 +183,11 @@ async def coroutine( # type: ignore [empty-body]
async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined,index]
data = await self._encode_input(self, self._len_inputs, self._prep_request_data, *args)
async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined,index]
output = await self._web3.eth.call({"to": self._address, "data": data}, block_identifier)
output = await self._web3.eth.call(
{"to": self._address, "data": data}, block_identifier
)
try:
decoded = await self._decode_output(self, output)
except InsufficientDataBytes as e:
raise InsufficientDataBytes(str(e), self, self._address, output) from e
return decoded if decimals is None else decoded / 10 ** Decimal(decimals)
return decoded if decimals is None else decoded / 10 ** Decimal(decimals)
33 changes: 22 additions & 11 deletions dank_mids/brownie_patch/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"""
decode = lambda self, data: ENVS.BROWNIE_DECODER_PROCESSES.run(__decode_output, data, self.abi) # type: ignore [attr-defined]


def _patch_call(call: ContractCall, w3: DankWeb3) -> None:
"""
Patch a Brownie ContractCall to enable asynchronous use via dank_mids for batching.
Expand All @@ -53,22 +54,23 @@ def _patch_call(call: ContractCall, w3: DankWeb3) -> None:
A patched version of the ContractCall with enhanced functionality.
"""
call._skip_decoder_proc_pool = call._address in _skip_proc_pool
call.coroutine = MethodType(_get_coroutine_fn(w3, len(call.abi['inputs'])), call)
call.coroutine = MethodType(_get_coroutine_fn(w3, len(call.abi["inputs"])), call)
call.__await__ = MethodType(_call_no_args, call)


@functools.lru_cache
def _get_coroutine_fn(w3: DankWeb3, len_inputs: int):
if ENVS.OPERATION_MODE.application or len_inputs: # type: ignore [attr-defined]
get_request_data = encode
else:
get_request_data = _request_data_no_args # type: ignore [assignment]

async def coroutine(
self: ContractCall,
*args: Any,
block_identifier: Optional[BlockIdentifier] = None,
decimals: Optional[int] = None,
override: Optional[Dict[str, str]] = None
override: Optional[Dict[str, str]] = None,
) -> Any:
if override:
raise ValueError("Cannot use state override with `coroutine`.")
Expand All @@ -80,21 +82,25 @@ async def coroutine(
decoded = await decode_output(self, output)
except InsufficientDataBytes as e:
raise InsufficientDataBytes(str(e), self, self._address, output) from e

return decoded if decimals is None else decoded / 10 ** Decimal(decimals)

return coroutine


def _call_no_args(self: ContractMethod):
"""Asynchronously call `self` with no arguments at the latest block."""
return self.coroutine().__await__()


async def encode_input(call: ContractCall, len_inputs, get_request_data, *args) -> HexStr:
if any(isinstance(arg, Contract) for arg in args) or any(hasattr(arg, "__contains__") for arg in args): # We will just assume containers contain a Contract object until we have a better way to handle this
if any(isinstance(arg, Contract) for arg in args) or any(
hasattr(arg, "__contains__") for arg in args
): # We will just assume containers contain a Contract object until we have a better way to handle this
# We can't unpickle these because of the added `coroutine` method.
data = __encode_input(call.abi, call.signature, *args)
else:
try: # We're better off sending these to the subprocess so they don't clog up the event loop.
try: # We're better off sending these to the subprocess so they don't clog up the event loop.
data = await get_request_data(call, *args)
except (AttributeError, TypeError):
# These occur when we have issues pickling an object, but that's fine, we can do it sync.
Expand Down Expand Up @@ -132,20 +138,23 @@ async def decode_output(call: ContractCall, data: bytes) -> Any:
if isinstance(decoded, Exception):
raise decoded
return decoded
except AttributeError as e:
except AttributeError as e:
# NOTE: Not sure why this happens as we set the attr while patching the call but w/e, this works for now
if not str(e).endswith(" object has no attribute '_skip_decoder_proc_pool'"):
raise
logger.debug("DEBUG ME BRO: %s", e)
call._skip_decoder_proc_pool = call._address in _skip_proc_pool
return await decode_output(call, data)


async def _request_data_no_args(call: ContractCall) -> HexStr:
return call.signature


# These methods were renamed in eth-abi 4.0.0
__eth_abi_encode = eth_abi.encode if hasattr(eth_abi, 'encode') else eth_abi.encode_abi
__eth_abi_decode = eth_abi.decode if hasattr(eth_abi, 'decode') else eth_abi.decode_abi
__eth_abi_encode = eth_abi.encode if hasattr(eth_abi, "encode") else eth_abi.encode_abi
__eth_abi_decode = eth_abi.decode if hasattr(eth_abi, "decode") else eth_abi.decode_abi


def __encode_input(abi: Dict[str, Any], signature: str, *args: Any) -> Union[HexStr, Exception]:
try:
Expand All @@ -155,6 +164,7 @@ def __encode_input(abi: Dict[str, Any], signature: str, *args: Any) -> Union[Hex
except Exception as e:
return e


_skip_proc_pool = {"0xcA11bde05977b3631167028862bE2a173976CA11"} # multicall3
# NOTE: retry 429 errors if running multiple services on same rpc
while True:
Expand All @@ -166,7 +176,8 @@ def __encode_input(abi: Dict[str, Any], signature: str, *args: Any) -> Union[Hex
raise
if multicall2 := MULTICALL2_ADDRESSES.get(chainid, None):
_skip_proc_pool.add(to_checksum_address(multicall2))



def __decode_output(hexstr: BytesLike, abi: Dict[str, Any]) -> Any:
try:
types_list = get_type_strings(abi["outputs"])
Expand All @@ -178,6 +189,7 @@ def __decode_output(hexstr: BytesLike, abi: Dict[str, Any]) -> Any:
except Exception as e:
return e


def __validate_output(abi: Dict[str, Any], hexstr: BytesLike):
try:
selector = HexBytes(hexstr)[:4].hex()
Expand All @@ -200,4 +212,3 @@ def __validate_output(abi: Dict[str, Any], hexstr: BytesLike):
raise VirtualMachineError(e) from None
except:
raise e from e.__cause__

Loading

0 comments on commit 2855502

Please sign in to comment.