Skip to content

Commit

Permalink
feat: use weakrefs to enable call cancellation (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored May 3, 2024
1 parent 8a7c0c6 commit 7c0235a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 45 deletions.
69 changes: 42 additions & 27 deletions dank_mids/_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def start(self, batch: Union["_Batch", "DankBatch", None] = None) -> None:
pass

async def _debug_daemon(self) -> None:
while not self._done.is_set():
# NOTE: _resonse works for RPCRequst and eth_call, _done works for _Batch classes
while self and self._response is None and not self._done.is_set():
await asyncio.sleep(60)
if not self._done.is_set():
logger.debug(f"{self} has not received data after {time.time() - self._start}s")
Expand Down Expand Up @@ -146,7 +147,7 @@ def __len__(self) -> int:
return 1

def __repr__(self) -> str:
return f"<{self.__class__.__name__} uid={self.uid} method={self.method}>"
return f"<{self.__class__.__name__} uid={self.uid} method={self.method} params={self.params}>"

@property
def request(self) -> Union[Request, PartialRequest]:
Expand Down Expand Up @@ -279,7 +280,7 @@ async def create_duplicate(self) -> RPCResponse: # Not actually self, but for ty
# We need to check the semaphore again to ensure we have the right context manager, soon but not right away.
# Creating the task before awaiting the new call ensures the new call will grab the semaphore immediately
# and then the task will try to acquire at the very next event loop _run_once cycle
logger.warning(f"call {self.uid} got stuck, we're creating a new one")
logger.warning("%s got stuck, we're creating a new one", self)
retval = await self.controller(self.method, self.params)
await self.semaphore.acquire()
return retval
Expand Down Expand Up @@ -312,6 +313,9 @@ def __init__(self, controller: "DankMiddlewareController", params: Any, retry: b
"""The block height at which the contract will be called."""
super().__init__(controller, "eth_call", params) # type: ignore

def __repr__(self) -> str:
return f"<{self.__class__.__name__} uid={self.uid} params={self.params}>"

@property
def calldata(self) -> HexBytes:
"""The calldata for the call."""
Expand Down Expand Up @@ -361,11 +365,11 @@ def semaphore(self) -> a_sync.Semaphore:
_Request = TypeVar("_Request", bound=_RequestMeta)

class _Batch(_RequestMeta[List[_Response]], Iterable[_Request]):
__slots__ = 'calls', '_lock', '_daemon'
calls: List[_Request]
__slots__ = '_calls', '_lock', '_daemon'
_calls: List["weakref.ref[_Request]"]
def __init__(self, controller: "DankMiddlewareController", calls: Iterable[_Request]):
self.controller = controller
self.calls = [weakref.proxy(call, callback=self._remove) for call in calls]
self._calls = [weakref.ref(call) for call in calls]
self._lock = _AlertingRLock(name=self.__class__.__name__)
super().__init__()

Expand All @@ -385,6 +389,11 @@ def __iter__(self) -> Iterator[_Request]:
def __len__(self) -> int:
return len(self.calls)

@property
def calls(self) -> List[_Request]:
"Returns a list of calls. Creates a temporary strong reference to each call in the batch, if it still exists."
return [call for ref in self._calls if (call := ref())]

@property
def halfpoint(self) -> int:
return len(self) // 2
Expand All @@ -408,7 +417,7 @@ def is_full(self) -> bool:

def append(self, call: _Request, skip_check: bool = False) -> None:
with self._lock:
self.calls.append(call)
self._calls.append(weakref.ref(call))
#self._len += 1
if not skip_check:
if self.is_full:
Expand All @@ -418,8 +427,7 @@ def append(self, call: _Request, skip_check: bool = False) -> None:

def extend(self, calls: Iterable[_Request], skip_check: bool = False) -> None:
with self._lock:
self.calls.extend(calls)
#self._len += len(calls)
self._calls.extend(weakref.ref(call) for call in calls)
if not skip_check:
if self.is_full:
self.start()
Expand Down Expand Up @@ -533,7 +541,9 @@ async def get_response(self) -> None: # type: ignore [override]
rid = self.controller.request_uid.next
demo_logger.info(f'request {rid} for multicall {self.bid} starting') # type: ignore
try:
await self.spoof_response(await self.controller.make_request(self.method, self.params, request_id=self.uid))
# create a strong ref to all calls we will execute so they cant get gced mid execution and mess up response ordering
calls = self.calls
await self.spoof_response(await self.controller.make_request(self.method, self.params, request_id=self.uid), calls)
except internal_err_types.__args__ as e: # type: ignore [attr-defined]
raise e if 'invalid argument' in str(e) else DankMidsInternalError(e) from e
except ClientResponseError as e:
Expand Down Expand Up @@ -564,12 +574,15 @@ def should_retry(self, e: Exception) -> bool:
return len(self) > 1

@set_done
async def spoof_response(self, data: Union[RawResponse, Exception]) -> None:
async def spoof_response(self, data: Union[RawResponse, Exception], calls: Optional[List[eth_call]] = None) -> None:
# NOTE: we pass in the calls to create a strong reference so when we zip up the results everything gets to the right place
if calls is None:
calls = self.calls
# This happens if an Exception takes place during a singular Multicall request.
if isinstance(data, Exception):
logger.debug("%s had Exception %s", self, data)
logger.debug("propagating the %s to all %s's calls", data.__class__.__name__, self)
await asyncio.gather(*[call.spoof_response(data) for call in self.calls])
await asyncio.gather(*[call.spoof_response(data) for call in calls])
# A `RawResponse` represents either a successful or a failed response, stored as pre-decoded bytes.
# It was either received as a response to a single rpc call or as a part of a batch response.
elif isinstance(data, RawResponse):
Expand All @@ -579,7 +592,7 @@ async def spoof_response(self, data: Union[RawResponse, Exception]) -> None:
# NOTE: We raise the exception which will be caught, call will be broken up and retried
raise response.exception
logger.debug("%s received valid bytes from the rpc", self)
await asyncio.gather(*(call.spoof_response(data) for call, (_, data) in zip(self.calls, await self.decode(response))))
await asyncio.gather(*(call.spoof_response(data) for call, (_, data) in zip(calls, await self.decode(response))))
else:
raise NotImplementedError(f"type {type(data)} not supported.", data)

Expand Down Expand Up @@ -649,13 +662,13 @@ def __repr__(self) -> str:

@property
def data(self) -> bytes:
if not self.calls:
if not self:
raise EmptyBatch(f"batch {self.uid} is empty and should not be processed.")
try:
return _codec.encode([call.request for call in self.calls])
return _codec.encode([call.request for call in self])
except TypeError:
# If we can't encode one of the calls, lets figure out which one and pass some useful info downstream
for call in self.calls:
for call in self:
try:
_codec.encode(call.request)
except TypeError as e:
Expand All @@ -665,7 +678,7 @@ def data(self) -> bytes:
@property
def is_multicalls_only(self) -> bool:
with self._lock:
return all(isinstance(call, Multicall) for call in self.calls)
return all(isinstance(call, Multicall) for call in self)

@property
def is_single_multicall(self) -> bool:
Expand All @@ -676,7 +689,7 @@ def is_single_multicall(self) -> bool:
def method_counts(self) -> Dict[RPCEndpoint, int]:
counts: DefaultDict[RPCEndpoint, int] = defaultdict(int)
with self._lock:
for call in self.calls:
for call in self:
if isinstance(call, Multicall):
counts["eth_call[multicall]"] += len(call) # type: ignore
else:
Expand All @@ -686,7 +699,7 @@ def method_counts(self) -> Dict[RPCEndpoint, int]:
@property
def total_calls(self) -> int:
with self._lock:
return sum(len(call) for call in self.calls)
return sum(len(call) for call in self)

@property
def is_full(self) -> bool:
Expand All @@ -701,10 +714,10 @@ async def get_response(self) -> None: # type: ignore [override]
rid = self.controller.request_uid.next
if ENVS.DEMO_MODE: # type: ignore [attr-defined]
# When demo mode is disabled, we can save some CPU time by skipping this sum
demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} ({sum(len(batch) for batch in self.calls)} calls) starting') # type: ignore
demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} ({sum(len(batch) for batch in self)} calls) starting') # type: ignore
try:
# NOTE: We do this inline so we never have to allocate the response to memory
await self.spoof_response(await self.post())
await self.spoof_response(*await self.post())
# I want to see these asap when working on the lib.
except internal_err_types.__args__ as e: # type: ignore [attr-defined]
raise e if 'invalid argument' in str(e) else DankMidsInternalError(e) from e
Expand Down Expand Up @@ -739,9 +752,13 @@ async def get_response(self) -> None: # type: ignore [override]
demo_logger.info(f'request {rid} for jsonrpc batch {self.jid} complete') # type: ignore

@eth_retry.auto_retry
async def post(self) -> List[RawResponse]:
async def post(self) -> Tuple[List[RawResponse], List[Union[Multicall, RPCRequest]]]:
"this function raises `BadResponse` if a successful 'error' response was received from the rpc"
try:
# we need strong refs so the results all get to the right place
calls = self.calls
# for the multicalls too
mcall_calls_strong_refs = [call.calls for call in filter(lambda x: isinstance(x, Multicall), calls)] # type: ignore [union-attr]
response: JSONRPCBatchResponse = await _session.post(self.controller.endpoint, data=self.data, loads=_codec.decode_jsonrpc_batch)
except ClientResponseError as e:
if e.message == "Payload Too Large":
Expand All @@ -763,7 +780,7 @@ async def post(self) -> List[RawResponse]:
# NOTE: A successful response will be a list of `RawResponse` objects.
# A single `PartialResponse` implies an error.
if isinstance(response, list):
return response
return response, calls
# Oops, we failed.
if response.error.message.lower() in ['invalid request', 'parse error']: # type: ignore [union-attr]
# NOT SURE IF THIS ACTUALLY RUNS, CAN WE RECEIVE THIS TYPE RESPONSE FOR A JSON BATCH?
Expand All @@ -790,19 +807,17 @@ def should_retry(self, e: Exception) -> bool:
return self.is_single_multicall

@set_done
async def spoof_response(self, response: List[RawResponse]) -> None:
async def spoof_response(self, response: List[RawResponse], calls: List[RPCRequest]) -> None:
# This means we got results. That doesn't mean they're good, but we got 'em.

if self.controller._sort_calls:
# NOTE: these providers don't always return batch results in the correct ordering
# NOTE: is it maybe because they
calls = sorted(self.calls, key=lambda call: call.uid)
calls = sorted(calls, key=lambda call: call.uid)
for i, (call, raw) in enumerate(zip(calls, response)):
# TODO: make sure this doesn't ever raise and then delete it
decoded = raw.decode()
assert call.uid == decoded.id, (i, call, decoded, response, [[call.uid for call in calls], [raw.decode() for raw in response]])
else:
calls = self.calls

for r in await asyncio.gather(*[call.spoof_response(raw) for call, raw in zip(calls, response)], return_exceptions=True):
# NOTE: By doing this with the exceptions we allow any successful calls to get their results sooner
Expand Down
38 changes: 23 additions & 15 deletions dank_mids/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,32 @@ def __repr__(self) -> str:
return f"<DankMiddlewareController instance={self._instance} chain={self.chain_id} endpoint={self.endpoint}>"

async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse:
with suppress(KeyError):
# some methods go thru a SmartProcessingQueue, we try this first
try:
queue = self.method_queues[method]
return await queue(self, method, params)
except TypeError as e:
if "unhashable type" in str(e):
return await queue(self, method, _helpers._make_hashable(params))
raise e


# eth_call go thru a specialized Semaphore and other methods pass thru unblocked
logger.debug(f'making {self.request_type.__name__} {method} with params {params}')
if method != "eth_call":
return await RPCRequest(self, method, params)
async with self.eth_call_semaphores[params[1]]:
if params[0]["to"] not in self.no_multicall:
if method == "eth_call":
async with self.eth_call_semaphores[params[1]]:
# create a strong ref to the call that will be held until the caller completes or is cancelled
logger.debug(f'making {self.request_type.__name__} {method} with params {params}')
if params[0]["to"] in self.no_multicall:
return await RPCRequest(self, method, params)
return await eth_call(self, params)

# some methods go thru a SmartProcessingQueue, we check those next
queue = self.method_queues[method]
logger.debug(f'making {self.request_type.__name__} {method} with params {params}')

# no queue, we can make the request normally
if queue is None:
return await RPCRequest(self, method, params)

# queue found, queue up the call and await the future
try:
# NOTE: is this a strong enough ref?
return await queue(self, method, params)
except TypeError as e:
if "unhashable type" not in str(e):
raise e
return await queue(self, method, _helpers._make_hashable(params))

@eth_retry.auto_retry
async def make_request(self, method: str, params: List[Any], request_id: Optional[int] = None) -> RawResponse:
Expand Down
2 changes: 1 addition & 1 deletion dank_mids/helpers/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
if isinstance(kwargs.get('data'), PartialRequest):
logger.debug("making request for %s", kwargs['data'])
kwargs['data'] = _codec.encode(kwargs['data'])
logger.debug("making request with (args, kwargs): (%s %s)", tuple(endpoint, *args), kwargs)
logger.debug("making request to %s with (args, kwargs): (%s %s)", endpoint, args, kwargs)

# Try the request until success or 5 failures.
tried = 0
Expand Down
4 changes: 2 additions & 2 deletions dank_mids/semaphores.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(self, controller: "DankMiddlewareController") -> None:
}
self.keys = self.method_queues.keys()
@functools.lru_cache(maxsize=None)
def __getitem__(self, method: RPCEndpoint) -> a_sync.SmartProcessingQueue:
def __getitem__(self, method: RPCEndpoint) -> Optional[a_sync.SmartProcessingQueue]:
try:
return next(self.method_queues[key] for key in self.keys if key in method)
except StopIteration:
raise KeyError(method) from None
return None

0 comments on commit 7c0235a

Please sign in to comment.