feat: WeakList implementation to reduce weakrefs overhead (#309)
* feat: WeakList implementation to reduce weakrefs overhead

* chore: `black .`

* fix: batch not subscriptable anymore

* fix: batch not subscriptable anymore

* fix: batch not subscriptable anymore

* feat: done property

* chore: bump version to 4.20.107

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
BobTheBuidler and github-actions[bot] authored Nov 28, 2024
1 parent 5204df3 commit 6a52b18
Showing 4 changed files with 81 additions and 47 deletions.
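Before the diffs, here is a minimal, self-contained sketch of the idea behind the change. It is not the library's actual class (that is the `WeakRequestList` added to dank_mids/_requests.py below); the `WeakList` and `Request` names are illustrative only. The point is that keying weak references by `id()` and pruning them from a GC callback makes `len()` and `bool()` plain dict operations instead of a pass that dereferences every stored `weakref.ref`:

```python
import gc
import weakref
from typing import Dict, Iterator


class WeakList:
    """Sketch only: a list-like container that holds items by weak reference."""

    def __init__(self) -> None:
        # Maps id(item) -> weakref.ref(item); dead entries are pruned by the callback.
        self._refs: Dict[int, weakref.ref] = {}

    def append(self, item: object) -> None:
        key = id(item)
        # NOTE: weakref callbacks receive the ref object, so capture the key here.
        self._refs[key] = weakref.ref(item, lambda _ref, key=key: self._refs.pop(key, None))

    def __len__(self) -> int:
        return len(self._refs)      # no dereferencing needed

    def __bool__(self) -> bool:
        return bool(self._refs)

    def __iter__(self) -> Iterator[object]:
        for ref in self._refs.values():
            item = ref()
            if item is not None:    # skip anything collected mid-iteration
                yield item


class Request:
    """Hypothetical stand-in for a dank_mids request object."""


wl = WeakList()
reqs = [Request() for _ in range(3)]
for r in reqs:
    wl.append(r)
assert len(wl) == 3    # sized and iterable, but (deliberately) not subscriptable
del reqs, r            # drop the only strong references
gc.collect()           # callbacks fire and prune the dead entries
assert len(wl) == 0
```

The real `WeakRequestList` below follows the same shape but also implements `extend`, `remove`, `__contains__`, and `__repr__`, and replaces the `List["weakref.ref[_Request]"]` that `_Batch` previously maintained by hand.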
12 changes: 6 additions & 6 deletions dank_mids/_batch.py
@@ -125,11 +125,11 @@ def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None,
         check_len = min(CHECK, self.controller.batcher.step)
         # Go thru the multicalls and add calls to the batch
         for mcall in self.multicalls.values():
-            # NOTE: If a multicall has less than `CHECK` calls, we should just throw the calls into a jsonrpc batch individually.
-            try:  # NOTE: This should be faster than using len().
-                mcall[check_len]
+            if len(mcall) >= check_len:
                 working_batch.append(mcall, skip_check=True)
-            except IndexError:
+            else:
+                # NOTE: If a multicall has less than `check_len` calls, we should
+                # just throw the calls into a jsonrpc batch individually.
                 working_batch.extend(mcall, skip_check=True)
             if working_batch.is_full:
                 yield working_batch
@@ -143,8 +143,8 @@ def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None,
                 working_batch.append(rpc_calls_to_batch.pop(), skip_check=True)
         if working_batch:
             if working_batch.is_single_multicall:
-                yield working_batch[0]  # type: ignore [misc]
+                yield next(iter(working_batch))  # type: ignore [misc]
             elif len(working_batch) == 1:
-                yield working_batch[0].make_request()  # type: ignore [union-attr]
+                yield next(iter(working_batch)).make_request()  # type: ignore [union-attr]
             else:
                 yield working_batch
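Both hunks follow from the same constraint: batches backed by the new weak-reference container are sized and iterable but no longer subscriptable, so the `IndexError` probe and the `[0]` lookups give way to `len()` and `next(iter(...))`. A throwaway illustration (a `set` is used purely as a stand-in for a sized, non-subscriptable iterable, not a real batch):

```python
calls = {"call_a", "call_b"}        # hypothetical stand-in, not a real batch

# calls[0]                          # TypeError: 'set' object is not subscriptable
first = next(iter(calls))           # idiomatic first-element access for any iterable
big_enough = len(calls) >= 2        # cheap length check replaces the IndexError probe
```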
109 changes: 72 additions & 37 deletions dank_mids/_requests.py
@@ -117,9 +117,13 @@ def __await__(self) -> Generator[Any, None, _Response]:
     def __len__(self) -> int:
         pass
 
+    @property
+    def done(self) -> bool:
+        return self._done.is_set()
+
     @property
     def response(self) -> _Response:
-        if not self._done.is_set():
+        if not self.done:
             raise ResponseNotReady(self)
         return self._response
 
@@ -133,9 +137,9 @@ def start(self, batch: Union["_Batch", "DankBatch", None] = None) -> None:
 
     async def _debug_daemon(self) -> None:
         # NOTE: _resonse works for RPCRequst and eth_call, _done works for _Batch classes
-        while self and self._response is None and not self._done.is_set():
+        while self and self._response is None and not self.done:
             await asyncio.sleep(60)
-            if not self._done.is_set():
+            if not self.done:
                 logger.debug(f"{self} has not received data after {time.time() - self._start}s")
 
 
@@ -495,40 +499,73 @@ def semaphore(self) -> a_sync.Semaphore:
 _Request = TypeVar("_Request", bound=_RequestMeta)
 
 
+class WeakRequestList(Generic[_Request]):
+    def __init__(self, data=None):
+        self._refs = {}  # Mapping from object ID to weak reference
+        if data is not None:
+            for item in data:
+                self.append(item)
+
+    def _gc_callback(self, item: _Request) -> None:
+        # Callback when a weakly-referenced object is garbage collected
+        self._refs.pop(id(item), None)  # Safely remove the item if it exists
+
+    def append(self, item: _Request) -> None:
+        # Keep a weak reference with a callback for when the item is collected
+        ref = weakref.ref(item, self._gc_callback)
+        self._refs[id(item)] = ref
+
+    def extend(self, items: Iterable[_Request]) -> None:
+        for item in items:
+            self.append(item)
+
+    def __len__(self) -> int:
+        return len(self._refs)
+
+    def __bool__(self) -> bool:
+        return bool(self._refs)
+
+    def remove(self, item: _Request) -> None:
+        obj_id = id(item)
+        ref = self._refs.get(obj_id)
+        if ref is not None and ref() is item:
+            del self._refs[obj_id]
+        else:
+            raise ValueError("list.remove(x): x not in list")
+
+    def __contains__(self, item: _Request) -> bool:
+        ref = self._refs.get(id(item))
+        return ref is not None and ref() is item
+
+    def __iter__(self) -> Iterator[_Request]:
+        for ref in self._refs.values():
+            item = ref()
+            if item is not None:
+                yield item
+
+    def __repr__(self):
+        # Use list comprehension syntax within the repr function for clarity
+        return f"WeakList([{', '.join(repr(item) for item in self)}])"
+
+
 class _Batch(_RequestMeta[List[_Response]], Iterable[_Request]):
-    __slots__ = "_calls", "_lock", "_daemon"
-    _calls: List["weakref.ref[_Request]"]
+    __slots__ = "calls", "_lock", "_daemon"
+    calls: WeakRequestList[_Request]
 
     def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]):
         self.controller = controller
-        self._calls = [weakref.ref(call) for call in calls]
+        self.calls = WeakRequestList(calls)
         self._lock = _AlertingRLock(name=self.__class__.__name__)
         super().__init__()
 
     def __bool__(self) -> bool:
-        try:
-            next(self.calls)
-            return True
-        except StopIteration:
-            return False
-
-    @overload
-    def __getitem__(self, ix: int) -> _Request: ...
-    @overload
-    def __getitem__(self, ix: slice) -> Tuple[_Request, ...]: ...
-    def __getitem__(self, ix: Union[int, slice]) -> Union[_Request, Tuple[_Request, ...]]:
-        return tuple(self.calls)[ix]
+        return bool(self.calls)
 
     def __iter__(self) -> Iterator[_Request]:
-        return self.calls
+        return iter(self.calls)
 
     def __len__(self) -> int:
-        return sum(1 for _ in self.calls)
-
-    @property
-    def calls(self) -> Iterator[_Request]:
-        "Returns a list of calls. Creates a temporary strong reference to each call in the batch, if it still exists."
-        return (call for ref in self._calls if (call := ref()))
+        return len(self.calls)
 
     @property
     def bisected(self) -> Generator[Tuple[_Request, ...], None, None]:
@@ -544,7 +581,7 @@ def is_full(self) -> bool:
 
     def append(self, call: _Request, skip_check: bool = False) -> None:
         with self._lock:
-            self._calls.append(weakref.ref(call))
+            self.calls.append(call)
             # self._len += 1
             if not skip_check:
                 if self.is_full:
@@ -554,7 +591,7 @@ def append(self, call: _Request, skip_check: bool = False) -> None:
 
     def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None:
         with self._lock:
-            self._calls.extend(weakref.ref(call) for call in calls)
+            self.calls.extend(calls)
             if not skip_check:
                 if self.is_full:
                     self.start()
@@ -642,7 +679,7 @@ def __repr__(self) -> str:
 
     @cached_property
     def block(self) -> BlockId:
-        return next(self.calls).block
+        return next(iter(self.calls)).block
 
     @property
     def calldata(self) -> str:
@@ -759,20 +796,18 @@ async def spoof_response(
 
     async def decode(self, data: PartialResponse) -> List[Tuple[bool, bytes]]:
         start = time.time()
-        if ENVS.OPERATION_MODE.infura:  # type: ignore [attr-defined]
+        if ENVS.OPERATION_MODE.infura or len(self) < 100:
+            # decode synchronously
             retval = mcall_decode(data)
         else:
-            try:  # NOTE: Quickly check for length without counting each item with `len`.
-                if not ENVS.OPERATION_MODE.application:  # type: ignore [attr-defined]
-                    self[100]
+            try:
                 retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data)  # type: ignore [attr-defined]
-            except IndexError:
-                retval = mcall_decode(data)
             except BrokenProcessPool:
                 # TODO: Move this somewhere else
                 logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data)  # type: ignore [attr-defined]
                 ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers)  # type: ignore [attr-defined,assignment]
                 retval = mcall_decode(data)
+
         stats.log_duration(f"multicall decoding for {len(self)} calls", start)
         # Raise any Exceptions that may have come out of the process pool.
         if isinstance(retval, Exception):
@@ -807,7 +842,7 @@ async def bisect_and_retry(self, e: Exception) -> List[RPCResponse]:
 
     @set_done
     async def _exec_single_call(self) -> None:
-        await next(self.calls).make_request()
+        await next(iter(self.calls)).make_request()
 
     def _post_future_cleanup(self) -> None:
         with suppress(KeyError):
@@ -1108,8 +1143,8 @@ async def bisect_and_retry(self, e: Exception) -> None:
         logger.debug("%s had exception %s, retrying", self, e)
         batches = [
             (
-                Multicall(self.controller, tuple(chunk[0].calls), f"json{self.jid}_{i}")  # type: ignore [misc]
-                if len(chunk) == 1 and isinstance(chunk[0], Multicall)
+                Multicall(self.controller, first_call, f"json{self.jid}_{i}")  # type: ignore [misc]
+                if len(chunk) == 1 and isinstance(first_call := next(iter(chunk)), Multicall)
                 else JSONRPCBatch(self.controller, chunk, f"{self.jid}_{i}")
             )
             for i, chunk in enumerate(self.bisected)
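The last hunk relies on evaluation order in conditional expressions: the condition is evaluated before the branch expression, so a name bound there with `:=` is available even though it appears earlier in the text. A standalone illustration with throwaway names:

```python
chunks = [["only-item"], ["a", "b"]]     # hypothetical inputs

results = [
    (
        f"single:{first}"                # uses the name bound in the condition below
        if len(chunk) == 1 and isinstance(first := next(iter(chunk)), str)
        else f"batch:{len(chunk)}"
    )
    for chunk in chunks
]
assert results == ["single:only-item", "batch:2"]
```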
5 changes: 2 additions & 3 deletions dank_mids/controller.py
@@ -264,10 +264,9 @@ async def execute_batch(self) -> None:
         and executes them as a single batch.
         """
         with self.pools_closed_lock:  # Do we really need this?  # NOTE: yes we do
-            multicalls = dict(self.pending_eth_calls)
+            multicalls = self.pending_eth_calls.copy()
             self.pending_eth_calls.clear()
             self.num_pending_eth_calls = 0
-            rpc_calls = self.pending_rpc_calls[:]
+            rpc_calls = self.pending_rpc_calls
             self.pending_rpc_calls = JSONRPCBatch(self)
         demo_logger.info(f"executing dank batch (current cid: {self.call_uid.latest})")  # type: ignore
         batch = DankBatch(self, multicalls, rpc_calls)
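The controller change is the usual hand-off-and-replace move: since the pending batch is about to be replaced wholesale anyway, the object itself is handed to the new `DankBatch` and the attribute is rebound to a fresh, empty batch while the lock is held, instead of slicing a copy (which the new batch type no longer supports). A generic sketch of the pattern, with placeholder types rather than the real dank_mids classes:

```python
import threading
from typing import List


class Controller:
    """Illustrative only; not the dank_mids controller."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.pending: List[str] = []       # stands in for the pending JSONRPCBatch

    def drain(self) -> List[str]:
        # Hand off the current batch and rebind to a fresh one atomically.
        with self._lock:
            batch, self.pending = self.pending, []
        return batch


c = Controller()
c.pending.extend(["rpc_call_1", "rpc_call_2"])
assert c.drain() == ["rpc_call_1", "rpc_call_2"]
assert c.pending == []
```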
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dank-mids"
-version = "4.20.106"
+version = "4.20.107"
 description = "Multicall batching middleware for asynchronous scripts using web3.py"
 authors = ["BobTheBuidler <bobthebuidlerdefi@gmail.com>"]
 homepage = "https://github.com/BobTheBuidler/dank_mids"
