From 7c0235afb4c434b7e4003451abe92ce0240677f2 Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Fri, 3 May 2024 19:15:54 -0400 Subject: [PATCH] feat: use weakrefs to enable call cancellation (#199) --- dank_mids/_requests.py | 69 +++++++++++++++++++++-------------- dank_mids/controller.py | 38 +++++++++++-------- dank_mids/helpers/_session.py | 2 +- dank_mids/semaphores.py | 4 +- 4 files changed, 68 insertions(+), 45 deletions(-) diff --git a/dank_mids/_requests.py b/dank_mids/_requests.py index 62358d94..1e35bcdc 100644 --- a/dank_mids/_requests.py +++ b/dank_mids/_requests.py @@ -94,7 +94,8 @@ def start(self, batch: Union["_Batch", "DankBatch", None] = None) -> None: pass async def _debug_daemon(self) -> None: - while not self._done.is_set(): + # NOTE: _resonse works for RPCRequst and eth_call, _done works for _Batch classes + while self and self._response is None and not self._done.is_set(): await asyncio.sleep(60) if not self._done.is_set(): logger.debug(f"{self} has not received data after {time.time() - self._start}s") @@ -146,7 +147,7 @@ def __len__(self) -> int: return 1 def __repr__(self) -> str: - return f"<{self.__class__.__name__} uid={self.uid} method={self.method}>" + return f"<{self.__class__.__name__} uid={self.uid} method={self.method} params={self.params}>" @property def request(self) -> Union[Request, PartialRequest]: @@ -279,7 +280,7 @@ async def create_duplicate(self) -> RPCResponse: # Not actually self, but for ty # We need to check the semaphore again to ensure we have the right context manager, soon but not right away. # Creating the task before awaiting the new call ensures the new call will grab the semaphore immediately # and then the task will try to acquire at the very next event loop _run_once cycle - logger.warning(f"call {self.uid} got stuck, we're creating a new one") + logger.warning("%s got stuck, we're creating a new one", self) retval = await self.controller(self.method, self.params) await self.semaphore.acquire() return retval @@ -312,6 +313,9 @@ def __init__(self, controller: "DankMiddlewareController", params: Any, retry: b """The block height at which the contract will be called.""" super().__init__(controller, "eth_call", params) # type: ignore + def __repr__(self) -> str: + return f"<{self.__class__.__name__} uid={self.uid} params={self.params}>" + @property def calldata(self) -> HexBytes: """The calldata for the call.""" @@ -361,11 +365,11 @@ def semaphore(self) -> a_sync.Semaphore: _Request = TypeVar("_Request", bound=_RequestMeta) class _Batch(_RequestMeta[List[_Response]], Iterable[_Request]): - __slots__ = 'calls', '_lock', '_daemon' - calls: List[_Request] + __slots__ = '_calls', '_lock', '_daemon' + _calls: List["weakref.ref[_Request]"] def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]): self.controller = controller - self.calls = [weakref.proxy(call, callback=self._remove) for call in calls] + self._calls = [weakref.ref(call) for call in calls] self._lock = _AlertingRLock(name=self.__class__.__name__) super().__init__() @@ -385,6 +389,11 @@ def __iter__(self) -> Iterator[_Request]: def __len__(self) -> int: return len(self.calls) + @property + def calls(self) -> List[_Request]: + "Returns a list of calls. Creates a temporary strong reference to each call in the batch, if it still exists." + return [call for ref in self._calls if (call := ref())] + @property def halfpoint(self) -> int: return len(self) // 2 @@ -408,7 +417,7 @@ def is_full(self) -> bool: def append(self, call: _Request, skip_check: bool = False) -> None: with self._lock: - self.calls.append(call) + self._calls.append(weakref.ref(call)) #self._len += 1 if not skip_check: if self.is_full: @@ -418,8 +427,7 @@ def append(self, call: _Request, skip_check: bool = False) -> None: def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None: with self._lock: - self.calls.extend(calls) - #self._len += len(calls) + self._calls.extend(weakref.ref(call) for call in calls) if not skip_check: if self.is_full: self.start() @@ -533,7 +541,9 @@ async def get_response(self) -> None: # type: ignore [override] rid = self.controller.request_uid.next demo_logger.info(f'request {rid} for multicall {self.bid} starting') # type: ignore try: - await self.spoof_response(await self.controller.make_request(self.method, self.params, request_id=self.uid)) + # create a strong ref to all calls we will execute so they cant get gced mid execution and mess up response ordering + calls = self.calls + await self.spoof_response(await self.controller.make_request(self.method, self.params, request_id=self.uid), calls) except internal_err_types.__args__ as e: # type: ignore [attr-defined] raise e if 'invalid argument' in str(e) else DankMidsInternalError(e) from e except ClientResponseError as e: @@ -564,12 +574,15 @@ def should_retry(self, e: Exception) -> bool: return len(self) > 1 @set_done - async def spoof_response(self, data: Union[RawResponse, Exception]) -> None: + async def spoof_response(self, data: Union[RawResponse, Exception], calls: Optional[List[eth_call]] = None) -> None: + # NOTE: we pass in the calls to create a strong reference so when we zip up the results everything gets to the right place + if calls is None: + calls = self.calls # This happens if an Exception takes place during a singular Multicall request. if isinstance(data, Exception): logger.debug("%s had Exception %s", self, data) logger.debug("propagating the %s to all %s's calls", data.__class__.__name__, self) - await asyncio.gather(*[call.spoof_response(data) for call in self.calls]) + await asyncio.gather(*[call.spoof_response(data) for call in calls]) # A `RawResponse` represents either a successful or a failed response, stored as pre-decoded bytes. # It was either received as a response to a single rpc call or as a part of a batch response. elif isinstance(data, RawResponse): @@ -579,7 +592,7 @@ async def spoof_response(self, data: Union[RawResponse, Exception]) -> None: # NOTE: We raise the exception which will be caught, call will be broken up and retried raise response.exception logger.debug("%s received valid bytes from the rpc", self) - await asyncio.gather(*(call.spoof_response(data) for call, (_, data) in zip(self.calls, await self.decode(response)))) + await asyncio.gather(*(call.spoof_response(data) for call, (_, data) in zip(calls, await self.decode(response)))) else: raise NotImplementedError(f"type {type(data)} not supported.", data) @@ -649,13 +662,13 @@ def __repr__(self) -> str: @property def data(self) -> bytes: - if not self.calls: + if not self: raise EmptyBatch(f"batch {self.uid} is empty and should not be processed.") try: - return _codec.encode([call.request for call in self.calls]) + return _codec.encode([call.request for call in self]) except TypeError: # If we can't encode one of the calls, lets figure out which one and pass some useful info downstream - for call in self.calls: + for call in self: try: _codec.encode(call.request) except TypeError as e: @@ -665,7 +678,7 @@ def data(self) -> bytes: @property def is_multicalls_only(self) -> bool: with self._lock: - return all(isinstance(call, Multicall) for call in self.calls) + return all(isinstance(call, Multicall) for call in self) @property def is_single_multicall(self) -> bool: @@ -676,7 +689,7 @@ def is_single_multicall(self) -> bool: def method_counts(self) -> Dict[RPCEndpoint, int]: counts: DefaultDict[RPCEndpoint, int] = defaultdict(int) with self._lock: - for call in self.calls: + for call in self: if isinstance(call, Multicall): counts["eth_call[multicall]"] += len(call) # type: ignore else: @@ -686,7 +699,7 @@ def method_counts(self) -> Dict[RPCEndpoint, int]: @property def total_calls(self) -> int: with self._lock: - return sum(len(call) for call in self.calls) + return sum(len(call) for call in self) @property def is_full(self) -> bool: @@ -701,10 +714,10 @@ async def get_response(self) -> None: # type: ignore [override] rid = self.controller.request_uid.next if ENVS.DEMO_MODE: # type: ignore [attr-defined] # When demo mode is disabled, we can save some CPU time by skipping this sum - demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} ({sum(len(batch) for batch in self.calls)} calls) starting') # type: ignore + demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} ({sum(len(batch) for batch in self)} calls) starting') # type: ignore try: # NOTE: We do this inline so we never have to allocate the response to memory - await self.spoof_response(await self.post()) + await self.spoof_response(*await self.post()) # I want to see these asap when working on the lib. except internal_err_types.__args__ as e: # type: ignore [attr-defined] raise e if 'invalid argument' in str(e) else DankMidsInternalError(e) from e @@ -739,9 +752,13 @@ async def get_response(self) -> None: # type: ignore [override] demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} complete') # type: ignore @eth_retry.auto_retry - async def post(self) -> List[RawResponse]: + async def post(self) -> Tuple[List[RawResponse], List[Union[Multicall, RPCRequest]]]: "this function raises `BadResponse` if a successful 'error' response was received from the rpc" try: + # we need strong refs so the results all get to the right place + calls = self.calls + # for the multicalls too + mcall_calls_strong_refs = [call.calls for call in filter(lambda x: isinstance(x, Multicall), calls)] # type: ignore [union-attr] response: JSONRPCBatchResponse = await _session.post(self.controller.endpoint, data=self.data, loads=_codec.decode_jsonrpc_batch) except ClientResponseError as e: if e.message == "Payload Too Large": @@ -763,7 +780,7 @@ async def post(self) -> List[RawResponse]: # NOTE: A successful response will be a list of `RawResponse` objects. # A single `PartialResponse` implies an error. if isinstance(response, list): - return response + return response, calls # Oops, we failed. if response.error.message.lower() in ['invalid request', 'parse error']: # type: ignore [union-attr] # NOT SURE IF THIS ACTUALLY RUNS, CAN WE RECEIVE THIS TYPE RESPONSE FOR A JSON BATCH? @@ -790,19 +807,17 @@ def should_retry(self, e: Exception) -> bool: return self.is_single_multicall @set_done - async def spoof_response(self, response: List[RawResponse]) -> None: + async def spoof_response(self, response: List[RawResponse], calls: List[RPCRequest]) -> None: # This means we got results. That doesn't mean they're good, but we got 'em. if self.controller._sort_calls: # NOTE: these providers don't always return batch results in the correct ordering # NOTE: is it maybe because they - calls = sorted(self.calls, key=lambda call: call.uid) + calls = sorted(calls, key=lambda call: call.uid) for i, (call, raw) in enumerate(zip(calls, response)): # TODO: make sure this doesn't ever raise and then delete it decoded = raw.decode() assert call.uid == decoded.id, (i, call, decoded, response, [[call.uid for call in calls], [raw.decode() for raw in response]]) - else: - calls = self.calls for r in await asyncio.gather(*[call.spoof_response(raw) for call, raw in zip(calls, response)], return_exceptions=True): # NOTE: By doing this with the exceptions we allow any successful calls to get their results sooner diff --git a/dank_mids/controller.py b/dank_mids/controller.py index 35951432..71f447e1 100644 --- a/dank_mids/controller.py +++ b/dank_mids/controller.py @@ -128,24 +128,32 @@ def __repr__(self) -> str: return f"" async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse: - with suppress(KeyError): - # some methods go thru a SmartProcessingQueue, we try this first - try: - queue = self.method_queues[method] - return await queue(self, method, params) - except TypeError as e: - if "unhashable type" in str(e): - return await queue(self, method, _helpers._make_hashable(params)) - raise e - + # eth_call go thru a specialized Semaphore and other methods pass thru unblocked - logger.debug(f'making {self.request_type.__name__} {method} with params {params}') - if method != "eth_call": - return await RPCRequest(self, method, params) - async with self.eth_call_semaphores[params[1]]: - if params[0]["to"] not in self.no_multicall: + if method == "eth_call": + async with self.eth_call_semaphores[params[1]]: + # create a strong ref to the call that will be held until the caller completes or is cancelled + logger.debug(f'making {self.request_type.__name__} {method} with params {params}') + if params[0]["to"] in self.no_multicall: + return await RPCRequest(self, method, params) return await eth_call(self, params) + + # some methods go thru a SmartProcessingQueue, we check those next + queue = self.method_queues[method] + logger.debug(f'making {self.request_type.__name__} {method} with params {params}') + + # no queue, we can make the request normally + if queue is None: return await RPCRequest(self, method, params) + + # queue found, queue up the call and await the future + try: + # NOTE: is this a strong enough ref? + return await queue(self, method, params) + except TypeError as e: + if "unhashable type" not in str(e): + raise e + return await queue(self, method, _helpers._make_hashable(params)) @eth_retry.auto_retry async def make_request(self, method: str, params: List[Any], request_id: Optional[int] = None) -> RawResponse: diff --git a/dank_mids/helpers/_session.py b/dank_mids/helpers/_session.py index 81e4bf76..bde344c6 100644 --- a/dank_mids/helpers/_session.py +++ b/dank_mids/helpers/_session.py @@ -92,7 +92,7 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC if isinstance(kwargs.get('data'), PartialRequest): logger.debug("making request for %s", kwargs['data']) kwargs['data'] = _codec.encode(kwargs['data']) - logger.debug("making request with (args, kwargs): (%s %s)", tuple(endpoint, *args), kwargs) + logger.debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs) # Try the request until success or 5 failures. tried = 0 diff --git a/dank_mids/semaphores.py b/dank_mids/semaphores.py index 2a33b3b9..6b7b9140 100644 --- a/dank_mids/semaphores.py +++ b/dank_mids/semaphores.py @@ -60,8 +60,8 @@ def __init__(self, controller: "DankMiddlewareController") -> None: } self.keys = self.method_queues.keys() @functools.lru_cache(maxsize=None) - def __getitem__(self, method: RPCEndpoint) -> a_sync.SmartProcessingQueue: + def __getitem__(self, method: RPCEndpoint) -> Optional[a_sync.SmartProcessingQueue]: try: return next(self.method_queues[key] for key in self.keys if key in method) except StopIteration: - raise KeyError(method) from None + return None