Skip to content

Commit

Permalink
add debug flag
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Jan 10, 2025
1 parent 7c0ef2f commit 3d17a96
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
16 changes: 13 additions & 3 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,24 @@ def fork(
url: str,
reset_traces=True,
block_identifier="safe",
debug=False,
deprecated=True,
**kwargs,
):
if deprecated:
warnings.warn("using boa.env.fork directly is deprecated; use `boa.fork`!")
return self.fork_rpc(EthereumRPC(url), reset_traces, block_identifier, **kwargs)
return self.fork_rpc(
EthereumRPC(url), reset_traces, block_identifier, debug, **kwargs
)

def fork_rpc(self, rpc: RPC, reset_traces=True, block_identifier="safe", **kwargs):
def fork_rpc(
self,
rpc: RPC,
reset_traces=True,
block_identifier="safe",
debug=False,
**kwargs,
):
"""
Fork the environment to a local chain.
:param rpc: RPC to fork from
Expand All @@ -82,7 +92,7 @@ def fork_rpc(self, rpc: RPC, reset_traces=True, block_identifier="safe", **kwarg
self.sha3_trace = {}
self.sstore_trace = {}

self.evm.fork_rpc(rpc, block_identifier, **kwargs)
self.evm.fork_rpc(rpc, block_identifier, debug=debug, **kwargs)

def get_gas_meter_class(self):
return self.evm.get_gas_meter_class()
Expand Down
36 changes: 31 additions & 5 deletions boa/vm/fork.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pickle
import sys
from pathlib import Path
from typing import Any, Type

Expand Down Expand Up @@ -35,7 +36,7 @@ class CachingRPC(RPC):
_loaded: dict[tuple[str, int, str], "CachingRPC"] = {}
_pid: int = os.getpid() # so we can detect if our fds are bad

def __new__(cls, rpc, chain_id, cache_dir=None):
def __new__(cls, rpc, chain_id, debug, cache_dir=None):
if isinstance(rpc, cls):
if rpc._chain_id == chain_id:
return rpc
Expand All @@ -55,12 +56,15 @@ def __new__(cls, rpc, chain_id, cache_dir=None):
return cls._loaded[(rpc.identifier, chain_id, cache_dir)]

ret = super().__new__(cls)
ret.__init__(rpc, chain_id, cache_dir)
ret.__init__(rpc, chain_id, debug, cache_dir)
cls._loaded[(rpc.identifier, chain_id, cache_dir)] = ret
return ret

def __init__(self, rpc: RPC, chain_id: int, cache_dir: str = None):
def __init__(
self, rpc: RPC, chain_id: int, debug: bool = False, cache_dir: str = None
):
self._rpc = rpc
self._debug = debug

self._chain_id = chain_id # TODO: check if this is needed

Expand Down Expand Up @@ -97,14 +101,29 @@ def name(self):
def _mk_key(self, method: str, params: Any) -> Any:
return pickle.dumps((method, params))

_col_limit = 97

def _debug_dump(self, item):
str_item = str(item)
# TODO: make this configurable
if len(str_item) > self._col_limit:
return str_item[: self._col_limit] + "..."
return str_item

def fetch(self, method, params):
# cannot dispatch into fetch_multi, doesn't work for debug_traceCall.
key = self._mk_key(method, params)
if self._debug:
print(method, self._debug_dump(params), file=sys.stderr)
if key in self._db:
ret = pickle.loads(self._db[key])
if self._debug:
print("(hit)", self._debug_dump(ret), file=sys.stderr)
return ret

result = self._rpc.fetch(method, params)
if self._debug:
print("(miss)", self._debug_dump(result), file=sys.stderr)
self._db[key] = pickle.dumps(result)
return result

Expand All @@ -120,6 +139,9 @@ def fetch_multi(self, payload):
key = self._mk_key(method, params)
try:
ret[item_ix] = pickle.loads(self._db[key])
if self._debug:
print(method, self._debug_dump(params), file=sys.stderr)
print("(hit)", self._debug_dump(ret[item_ix]), file=sys.stderr)
except KeyError:
keys.append((key, item_ix))
batch.append((method, params))
Expand All @@ -130,6 +152,10 @@ def fetch_multi(self, payload):
for result_ix, rpc_result in enumerate(self._rpc.fetch_multi(batch)):
key, item_ix = keys[result_ix]
ret[item_ix] = rpc_result
if self._debug:
params = batch[item_ix][1]
print(method, self._debug_dump(params), file=sys.stderr)
print("(miss)", self._debug_dump(rpc_result), file=sys.stderr)
self._db[key] = pickle.dumps(rpc_result)

return [ret[i] for i in range(len(ret))]
Expand All @@ -140,12 +166,12 @@ def fetch_multi(self, payload):
class AccountDBFork(AccountDB):
@classmethod
def class_from_rpc(
cls, rpc: RPC, block_identifier: str, **kwargs
cls, rpc: RPC, block_identifier: str, debug: bool, **kwargs
) -> Type["AccountDBFork"]:
class _ConfiguredAccountDB(AccountDBFork):
def __init__(self, *args, **kwargs2):
chain_id = int(rpc.fetch_uncached("eth_chainId", []), 16)
caching_rpc = CachingRPC(rpc, chain_id, **kwargs)
caching_rpc = CachingRPC(rpc, chain_id, debug, **kwargs)
super().__init__(
caching_rpc, chain_id, block_identifier, *args, **kwargs2
)
Expand Down
6 changes: 4 additions & 2 deletions boa/vm/py_evm.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,10 @@ def enable_fast_mode(self, flag: bool = True):
else:
unpatch_pyevm_state_object(self.vm.state)

def fork_rpc(self, rpc: RPC, block_identifier: str, **kwargs):
account_db_class = AccountDBFork.class_from_rpc(rpc, block_identifier, **kwargs)
def fork_rpc(self, rpc: RPC, block_identifier: str, debug: bool, **kwargs):
account_db_class = AccountDBFork.class_from_rpc(
rpc, block_identifier, debug, **kwargs
)
self._init_vm(account_db_class)

block_info = self.vm.state._account_db._block_info
Expand Down

0 comments on commit 3d17a96

Please sign in to comment.