From 6a52b189d872ed6c07b7b9768948849b749cdfb3 Mon Sep 17 00:00:00 2001
From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com>
Date: Thu, 28 Nov 2024 00:20:53 -0400
Subject: [PATCH] feat: WeakList implementation to reduce weakrefs overhead (#309)

* feat: WeakList implementation to reduce weakrefs overhead

* chore: `black .`

* fix: batch not subscriptable anymore

* fix: batch not subscriptable anymore

* fix: batch not subscriptable anymore

* feat: done property

* chore: bump version to 4.20.107

---------

Co-authored-by: github-actions[bot]
---
 dank_mids/_batch.py     |  12 ++---
 dank_mids/_requests.py  | 109 ++++++++++++++++++++++++++--------------
 dank_mids/controller.py |   5 +-
 pyproject.toml          |   2 +-
 4 files changed, 81 insertions(+), 47 deletions(-)

diff --git a/dank_mids/_batch.py b/dank_mids/_batch.py
index e6fd85f9..6db9a728 100644
--- a/dank_mids/_batch.py
+++ b/dank_mids/_batch.py
@@ -125,11 +125,11 @@ def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None,
         check_len = min(CHECK, self.controller.batcher.step)
         # Go thru the multicalls and add calls to the batch
         for mcall in self.multicalls.values():
-            # NOTE: If a multicall has less than `CHECK` calls, we should just throw the calls into a jsonrpc batch individually.
-            try:  # NOTE: This should be faster than using len().
-                mcall[check_len]
+            if len(mcall) >= check_len:
                 working_batch.append(mcall, skip_check=True)
-            except IndexError:
+            else:
+                # NOTE: If a multicall has less than `check_len` calls, we should
+                # just throw the calls into a jsonrpc batch individually.
                 working_batch.extend(mcall, skip_check=True)
             if working_batch.is_full:
                 yield working_batch
@@ -143,8 +143,8 @@ def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None,
             working_batch.append(rpc_calls_to_batch.pop(), skip_check=True)
         if working_batch:
             if working_batch.is_single_multicall:
-                yield working_batch[0]  # type: ignore [misc]
+                yield next(iter(working_batch))  # type: ignore [misc]
             elif len(working_batch) == 1:
-                yield working_batch[0].make_request()  # type: ignore [union-attr]
+                yield next(iter(working_batch)).make_request()  # type: ignore [union-attr]
             else:
                 yield working_batch
diff --git a/dank_mids/_requests.py b/dank_mids/_requests.py
index bf154114..c8fba5be 100644
--- a/dank_mids/_requests.py
+++ b/dank_mids/_requests.py
@@ -117,9 +117,13 @@ def __await__(self) -> Generator[Any, None, _Response]:
     def __len__(self) -> int:
         pass
 
+    @property
+    def done(self) -> bool:
+        return self._done.is_set()
+
     @property
     def response(self) -> _Response:
-        if not self._done.is_set():
+        if not self.done:
             raise ResponseNotReady(self)
         return self._response
 
@@ -133,9 +137,9 @@ def start(self, batch: Union["_Batch", "DankBatch", None] = None) -> None:
 
     async def _debug_daemon(self) -> None:
         # NOTE: _resonse works for RPCRequst and eth_call, _done works for _Batch classes
-        while self and self._response is None and not self._done.is_set():
+        while self and self._response is None and not self.done:
             await asyncio.sleep(60)
-            if not self._done.is_set():
+            if not self.done:
                 logger.debug(f"{self} has not received data after {time.time() - self._start}s")
 
 
@@ -495,40 +499,73 @@ def semaphore(self) -> a_sync.Semaphore:
 _Request = TypeVar("_Request", bound=_RequestMeta)
 
 
+class WeakRequestList(Generic[_Request]):
+    def __init__(self, data=None):
+        self._refs = {}  # Mapping from object ID to weak reference
+        if data is not None:
+            for item in data:
+                self.append(item)
+
+    def _gc_callback(self, item: _Request) -> None:
+        # Callback when a weakly-referenced object is garbage collected
+        self._refs.pop(id(item), None)  # Safely remove the item if it exists
+
+    def append(self, item: _Request) -> None:
+        # Keep a weak reference with a callback for when the item is collected
+        ref = weakref.ref(item, self._gc_callback)
+        self._refs[id(item)] = ref
+
+    def extend(self, items: Iterable[_Request]) -> None:
+        for item in items:
+            self.append(item)
+
+    def __len__(self) -> int:
+        return len(self._refs)
+
+    def __bool__(self) -> bool:
+        return bool(self._refs)
+
+    def remove(self, item: _Request) -> None:
+        obj_id = id(item)
+        ref = self._refs.get(obj_id)
+        if ref is not None and ref() is item:
+            del self._refs[obj_id]
+        else:
+            raise ValueError("list.remove(x): x not in list")
+
+    def __contains__(self, item: _Request) -> bool:
+        ref = self._refs.get(id(item))
+        return ref is not None and ref() is item
+
+    def __iter__(self) -> Iterator[_Request]:
+        for ref in self._refs.values():
+            item = ref()
+            if item is not None:
+                yield item
+
+    def __repr__(self):
+        # Use list comprehension syntax within the repr function for clarity
+        return f"WeakList([{', '.join(repr(item) for item in self)}])"
+
+
 class _Batch(_RequestMeta[List[_Response]], Iterable[_Request]):
-    __slots__ = "_calls", "_lock", "_daemon"
-    _calls: List["weakref.ref[_Request]"]
+    __slots__ = "calls", "_lock", "_daemon"
+    calls: WeakRequestList[_Request]
 
     def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]):
         self.controller = controller
-        self._calls = [weakref.ref(call) for call in calls]
+        self.calls = WeakRequestList(calls)
         self._lock = _AlertingRLock(name=self.__class__.__name__)
         super().__init__()
 
     def __bool__(self) -> bool:
-        try:
-            next(self.calls)
-            return True
-        except StopIteration:
-            return False
-
-    @overload
-    def __getitem__(self, ix: int) -> _Request: ...
-    @overload
-    def __getitem__(self, ix: slice) -> Tuple[_Request, ...]: ...
-    def __getitem__(self, ix: Union[int, slice]) -> Union[_Request, Tuple[_Request, ...]]:
-        return tuple(self.calls)[ix]
+        return bool(self.calls)
 
     def __iter__(self) -> Iterator[_Request]:
-        return self.calls
+        return iter(self.calls)
 
     def __len__(self) -> int:
-        return sum(1 for _ in self.calls)
-
-    @property
-    def calls(self) -> Iterator[_Request]:
-        "Returns a list of calls. Creates a temporary strong reference to each call in the batch, if it still exists."
-        return (call for ref in self._calls if (call := ref()))
 
     @property
     def bisected(self) -> Generator[Tuple[_Request, ...], None, None]:
@@ -544,7 +581,7 @@ def is_full(self) -> bool:
 
     def append(self, call: _Request, skip_check: bool = False) -> None:
         with self._lock:
-            self._calls.append(weakref.ref(call))
+            self.calls.append(call)
             # self._len += 1
             if not skip_check:
                 if self.is_full:
@@ -554,7 +591,7 @@ def append(self, call: _Request, skip_check: bool = False) -> None:
 
     def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None:
         with self._lock:
-            self._calls.extend(weakref.ref(call) for call in calls)
+            self.calls.extend(calls)
             if not skip_check:
                 if self.is_full:
                     self.start()
@@ -642,7 +679,7 @@ def __repr__(self) -> str:
 
     @cached_property
     def block(self) -> BlockId:
-        return next(self.calls).block
+        return next(iter(self.calls)).block
 
     @property
     def calldata(self) -> str:
@@ -759,20 +796,18 @@ async def spoof_response(
 
     async def decode(self, data: PartialResponse) -> List[Tuple[bool, bytes]]:
         start = time.time()
-        if ENVS.OPERATION_MODE.infura:  # type: ignore [attr-defined]
+        if ENVS.OPERATION_MODE.infura or len(self) < 100:
+            # decode synchronously
             retval = mcall_decode(data)
         else:
-            try:  # NOTE: Quickly check for length without counting each item with `len`.
-                if not ENVS.OPERATION_MODE.application:  # type: ignore [attr-defined]
-                    self[100]
+            try:
                 retval = await ENVS.MULTICALL_DECODER_PROCESSES.run(mcall_decode, data)  # type: ignore [attr-defined]
-            except IndexError:
-                retval = mcall_decode(data)
             except BrokenProcessPool:
                 # TODO: Move this somewhere else
                 logger.critical("Oh fuck, you broke the %s while decoding %s", ENVS.MULTICALL_DECODER_PROCESSES, data)  # type: ignore [attr-defined]
                 ENVS.MULTICALL_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.MULTICALL_DECODER_PROCESSES._max_workers)  # type: ignore [attr-defined,assignment]
                 retval = mcall_decode(data)
+
         stats.log_duration(f"multicall decoding for {len(self)} calls", start)
         # Raise any Exceptions that may have come out of the process pool.
         if isinstance(retval, Exception):
@@ -807,7 +842,7 @@ async def bisect_and_retry(self, e: Exception) -> List[RPCResponse]:
 
     @set_done
     async def _exec_single_call(self) -> None:
-        await next(self.calls).make_request()
+        await next(iter(self.calls)).make_request()
 
     def _post_future_cleanup(self) -> None:
         with suppress(KeyError):
@@ -1108,8 +1143,8 @@ async def bisect_and_retry(self, e: Exception) -> None:
         logger.debug("%s had exception %s, retrying", self, e)
         batches = [
             (
-                Multicall(self.controller, tuple(chunk[0].calls), f"json{self.jid}_{i}")  # type: ignore [misc]
-                if len(chunk) == 1 and isinstance(chunk[0], Multicall)
+                Multicall(self.controller, first_call, f"json{self.jid}_{i}")  # type: ignore [misc]
+                if len(chunk) == 1 and isinstance(first_call := next(iter(chunk)), Multicall)
                 else JSONRPCBatch(self.controller, chunk, f"{self.jid}_{i}")
             )
             for i, chunk in enumerate(self.bisected)
diff --git a/dank_mids/controller.py b/dank_mids/controller.py
index da183ea2..3f567235 100644
--- a/dank_mids/controller.py
+++ b/dank_mids/controller.py
@@ -264,10 +264,9 @@ async def execute_batch(self) -> None:
         and executes them as a single batch.
         """
         with self.pools_closed_lock:  # Do we really need this?  # NOTE: yes we do
-            multicalls = dict(self.pending_eth_calls)
+            multicalls = self.pending_eth_calls.copy()
             self.pending_eth_calls.clear()
-            self.num_pending_eth_calls = 0
-            rpc_calls = self.pending_rpc_calls[:]
+            rpc_calls = self.pending_rpc_calls
             self.pending_rpc_calls = JSONRPCBatch(self)
         demo_logger.info(f"executing dank batch (current cid: {self.call_uid.latest})")  # type: ignore
         batch = DankBatch(self, multicalls, rpc_calls)
diff --git a/pyproject.toml b/pyproject.toml
index 1e077e57..d13a7110 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dank-mids"
-version = "4.20.106"
+version = "4.20.107"
 description = "Multicall batching middleware for asynchronous scripts using web3.py"
 authors = ["BobTheBuidler "]
 homepage = "https://github.com/BobTheBuidler/dank_mids"
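
For reviewers who want to try the idea in isolation: below is a minimal, self-contained sketch of the pattern this patch introduces with `WeakRequestList` (an `id -> weakref` mapping whose entries evict themselves via a weakref callback, so `len()` and `bool()` become O(1) dict operations instead of walking a list of refs). It is illustration only and not the library code: the `WeakList` name, the `PendingRequest` stand-in, and the key-binding lambda are choices made for this demo, while the patch itself wires the equivalent class into `_Batch.calls`.

```python
import gc
import weakref
from typing import Dict, Generic, Iterator, TypeVar

T = TypeVar("T")


class WeakList(Generic[T]):
    """Sketch of the pattern: id -> weakref mapping with self-evicting entries."""

    def __init__(self) -> None:
        self._refs: Dict[int, weakref.ref] = {}

    def append(self, item: T) -> None:
        key = id(item)
        # weakref callbacks receive the weak reference object (not the dead item),
        # so bind the dict key up front and drop the slot when the referent dies.
        self._refs[key] = weakref.ref(item, lambda _ref, key=key: self._refs.pop(key, None))

    def __len__(self) -> int:
        # O(1) dict length, vs. the old `sum(1 for _ in self.calls)` walk.
        return len(self._refs)

    def __bool__(self) -> bool:
        return bool(self._refs)

    def __iter__(self) -> Iterator[T]:
        # Yield only items that are still alive, holding a temporary strong reference.
        for ref in list(self._refs.values()):
            item = ref()
            if item is not None:
                yield item


class PendingRequest:
    """Stand-in for an eth_call/RPCRequest object held by a batch."""


requests = [PendingRequest() for _ in range(3)]
batch: WeakList[PendingRequest] = WeakList()
for req in requests:
    batch.append(req)

print(len(batch))  # 3
del requests[0]    # drop the only strong reference to the first request
gc.collect()       # not required on CPython (refcounting frees it immediately), but harmless
print(len(batch))  # 2 -- the callback pruned the dead entry
```

This mirrors what the diff does to `_Batch`: the old `List[weakref.ref]` plus the `calls` generator property are replaced by a single container stored on `self.calls`, and call sites switch from `working_batch[0]` / `self[100]`-style indexing to `next(iter(...))` and plain `len()`.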