Skip to content

Commit

Permalink
feat: decimals kwarg for coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Apr 1, 2024
1 parent 3f9bf16 commit b32e48d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
18 changes: 13 additions & 5 deletions dank_mids/brownie_patch/call.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

import functools
import logging
from concurrent.futures.process import BrokenProcessPool
from functools import lru_cache
from decimal import Decimal
from pickle import PicklingError
from types import MethodType
from typing import Any, Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -30,18 +31,20 @@
def _patch_call(call: ContractCall, w3: Web3) -> None:
call._skip_decoder_proc_pool = call._address in _skip_proc_pool
call.coroutine = MethodType(_get_coroutine_fn(w3, len(call.abi['inputs'])), call)
call.__await__ = MethodType(__await_no_args__, call)

@lru_cache
@functools.lru_cache
def _get_coroutine_fn(w3: Web3, len_inputs: int):
if ENVS.OPERATION_MODE.application:
if ENVS.OPERATION_MODE.application or len_inputs:
get_request_data = encode
else:
get_request_data = encode if len_inputs else __request_data_no_args
get_request_data = __request_data_no_args

async def coroutine(
self: ContractCall,
*args: Tuple[Any,...],
block_identifier: Optional[Union[int, str, bytes]] = None,
decimals: Optional[int] = None,
override: Optional[Dict[str, str]] = None
) -> Any:
if override:
Expand All @@ -51,12 +54,17 @@ async def coroutine(
async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]:
output = await w3.eth.call({"to": self._address, "data": data}, block_identifier)
try:
return await decode_output(self, output)
decoded = await decode_output(self, output)
except InsufficientDataBytes as e:
raise InsufficientDataBytes(str(e), self, self._address, output) from e

return decoded if decimals is None else decoded / 10 ** Decimal(decimals)

return coroutine

def __await_no_args__(self: ContractMethod):
return self.coroutine().__await__()

async def encode_input(call: ContractCall, len_inputs, get_request_data, *args) -> bytes:
if any(isinstance(arg, Contract) for arg in args) or any(hasattr(arg, "__contains__") for arg in args): # We will just assume containers contain a Contract object until we have a better way to handle this
# We can't unpickle these because of the added `coroutine` method.
Expand Down
3 changes: 1 addition & 2 deletions dank_mids/brownie_patch/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dank_mids.brownie_patch.call import _get_coroutine_fn
from dank_mids.brownie_patch.overloaded import _patch_overloaded_method
from dank_mids.brownie_patch.types import ContractMethod


class Contract(brownie.Contract):
Expand All @@ -32,8 +33,6 @@ def patch_contract(contract: Union[Contract, brownie.Contract, str], w3: Optiona
_patch_if_method(v, w3)
return contract

ContractMethod = Union[ContractCall, ContractTx, OverloadedMethod]

def _patch_if_method(method: ContractMethod, w3: Web3) -> None:
if isinstance(method, (ContractCall, ContractTx)):
method.coroutine = MethodType(_get_coroutine_fn(w3, len(method.abi['inputs'])), method)
Expand Down
6 changes: 4 additions & 2 deletions dank_mids/brownie_patch/overloaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from brownie import Contract
from brownie.network.contract import ContractCall, ContractTx, OverloadedMethod
from dank_mids.brownie_patch.call import _get_coroutine_fn, _skip_proc_pool
from dank_mids.brownie_patch.types import ContractMethod
from web3 import Web3


Expand All @@ -15,7 +16,8 @@ async def coroutine(
self: Contract,
*args: Tuple[Any,...],
block_identifier: Optional[Union[int, str, bytes]] = None,
override: Optional[Dict[str, str]] = None
decimals: Optional[int] = None,
override: Optional[Dict[str, str]] = None,
) -> Any:
try:
fn = self._get_fn_from_args(args)
Expand All @@ -26,7 +28,7 @@ async def coroutine(
raise ValueError(f"{exc_str[:breakpoint]}.coroutine{exc_str[breakpoint:]}")
raise e

kwargs = {"block_identifier": block_identifier, "override": override}
kwargs = {"block_identifier": block_identifier, "decimals": decimals, "override": override}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
return await fn.coroutine(*args, **kwargs)

Expand Down
6 changes: 6 additions & 0 deletions dank_mids/brownie_patch/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

from typing import Union

from brownie.network.contract import ContractCall, ContractTx, OverloadedMethod

ContractMethod = Union[ContractCall, ContractTx, OverloadedMethod]

0 comments on commit b32e48d

Please sign in to comment.