Reduce resource consumption (#15)
* feat: eliminate repetitive dict lookup

* feat: skip unnecessary check

* feat: reduce resource consumption
BobTheBuidler authored Aug 9, 2022
1 parent 4029508 commit bf7edf7
Showing 1 changed file with 114 additions and 89 deletions.
203 changes: 114 additions & 89 deletions dank_mids/controller.py
@@ -4,7 +4,7 @@
from collections import defaultdict
from functools import lru_cache
from time import time
from typing import Any, Dict, List, Literal, Optional, Set, Union
from typing import Any, Generator, List, Literal, Optional, Set, Union

import multicall
from aiohttp import RequestInfo
@@ -44,7 +44,7 @@ def _err_msg(e: Exception) -> str:
raise e
return err_msg

def start_caller_event_loop(loop: asyncio.BaseEventLoop) -> None:
def start_worker_event_loop(loop: asyncio.BaseEventLoop) -> None:
"""
Used to start a second event loop in a separate thread which is used to reduce congestion on the main event loop.
This allows dank_mids to better communicate with your node while you abuse it with heavy loads.
@@ -62,12 +62,12 @@ def __init__(self, w3: Web3) -> None:
self.sync_w3.middleware_onion.clear()
self.sync_w3.provider.middlewares = tuple()
self.DO_NOT_BATCH: Set[str] = set()
self.pending_calls: defaultdict = defaultdict(dict)
self.in_process_calls: defaultdict = defaultdict(dict)
self.completed_calls: defaultdict = defaultdict(dict)
self.pending_calls: List[BatchedCall] = []
self.num_pending_calls: int = 0
self.batcher = multicall.multicall.batcher
self.caller_event_loop = asyncio.new_event_loop()
threading.Thread(target=lambda: start_caller_event_loop(self.caller_event_loop)).start()
self.is_running = False
self.worker_event_loop = asyncio.new_event_loop()
threading.Thread(target=lambda: start_worker_event_loop(self.worker_event_loop)).start()
self._bid: int = 0 # batch id
self._mid: int = 0 # multicall id
self._cid: int = 0 # call id
@@ -87,9 +87,8 @@ def __repr__(self) -> str:
async def __call__(self, params: Any) -> RPCResponse:
if not self._is_configured:
await self._setup()
block = params[1]
cid = await self.add_to_queue(params)
return await self.await_response(block,cid)
call = await self.add_to_queue(params)
return await call

@property
def next_bid(self) -> int:
@@ -107,6 +106,15 @@ def next_cid(self) -> int:
self._checkpoint = time()
return self._increment('cid')


async def taskmaster_loop(self) -> None:
self.is_running = True
while self.pending_calls:
await asyncio.sleep(0)
if (self.loop_is_ready or self.queue_is_full):
await self.execute_multicall()
self.is_running = False

@sort_lazy_logger
def should_batch(self, method: RPCEndpoint, params: Any) -> bool:
""" Determines whether or not a call should be passed to the DankMiddlewareController. """
@@ -118,38 +126,24 @@ def should_batch(self, method: RPCEndpoint, params: Any) -> bool:
return False
return True

async def add_to_queue(self, params: Any) -> int:
async def add_to_queue(self, params: Any) -> "BatchedCall":
""" Adds a call to the DankMiddlewareContoller's `pending_calls`. """
cid = self.next_cid
block = params[1]
while self._pools_closed:
await asyncio.sleep(0)
self.pending_calls[block][cid] = params
demo_logger.info(f'added to queue (cid: {cid})')
return cid

async def await_response(self, block: str, cid: int) -> RPCResponse:
while cid not in self.completed_calls[block]:
if cid in self.in_process_calls[block]:
while cid not in self.completed_calls[block]:
await asyncio.sleep(0)
return self.fetch_response(block,cid)
assert cid in self.pending_calls[block] or cid in self.in_process_calls[block], f"Something went wrong, call{cid} is missing from `pending_calls`."
if self.loop_is_ready() or self.queue_is_full(block):
asyncio.run_coroutine_threadsafe(self.execute_multicall(), self.caller_event_loop).result()
await asyncio.sleep(LOOP_INTERVAL)
return self.fetch_response(block,cid)
return BatchedCall(self, params)

@property
def loop_is_ready(self) -> bool:
return time() - self._checkpoint > LOOP_INTERVAL

def queue_is_full(self, block) -> bool:
return len(self.pending_calls[block]) >= self.batcher.step * 25

def fetch_response(self, block: str, cid: int) -> RPCResponse:
return self.completed_calls[block].pop(cid)
@property
def queue_is_full(self) -> bool:
return len(self.pending_calls) >= self.batcher.step * 25

async def execute_multicall(self) -> None:
asyncio.run_coroutine_threadsafe(self._execute_multicall(), self.worker_event_loop).result()

async def _execute_multicall(self) -> None:
i = 0
while self._cid_lock.locked():
if i // 500 == int(i // 500):
@@ -158,66 +152,58 @@ async def execute_multicall(self) -> None:
await asyncio.sleep(.1)
self._pools_closed = True
with self._cid_lock:
calls_to_exec = []
blocks = list(self.pending_calls.keys())
for block in blocks:
calls = self.pending_calls.pop(block)
for cid, params in calls.items():
self.in_process_calls[block][cid] = params
calls_to_exec.append((block, calls))
calls_to_exec = defaultdict(list)
for call in self.pending_calls:
calls_to_exec[call.block].append(call)
call.started = True
self.pending_calls.clear()
self.num_pending_calls = 0
self._pools_closed = False
demo_logger.info(f'executing multicall (current cid: {self._cid})')
await gather([self.process_block(block, calls) for block, calls in calls_to_exec])
await gather([self.process_block(block, calls) for block, calls in calls_to_exec.items()])
demo_logger.info('multicall complete')

async def process_block(self, block: str, calls: Dict[int,List]) -> None:
calls = [[cid, [params[0]['to'], HexBytes(params[0]['data'])]] for cid, params in calls.items()]
async def process_block(self, block: str, calls: List["BatchedCall"]) -> None:
demo_logger.info(f'executing {len(calls)} calls for block {block}')
batches = self.batcher.batch_calls(calls, self.batcher.step)
await gather([self.process_batch(batch,block) for batch in batches])

async def process_batch(self, batch: List, block: str, bid: Optional[int] = None) -> None:
async def process_batch(self, batch: List["BatchedCall"], block: str, bid: Optional[int] = None) -> None:
if bid is None:
bid = self.next_bid
mid = self.next_mid
demo_logger.info(f'tryBlockAndAggregate {mid} for batch {bid} starting')
batch = dict(batch)
cids = list(batch.keys())
inputs = list(batch.values())
try:
_, _, response = await self._multicall_for_block(block).coroutine([False, inputs])
_, _, response = await self._multicall_for_block(block).coroutine([False, [[call.target, call.calldata] for call in batch]])
demo_logger.info(f'tryBlockAndAggregate {mid} for batch {bid} complete')
await gather([
self.spoof_response(cid, block, params, data)
for cid, params, (_, data) in zip(cids, inputs, response)
])
await gather([call.spoof_response(data) for call, (_, data) in zip(batch, response)])

except Exception as e:
if len(inputs) == 1:
await self.spoof_response(cids[0], block, inputs[0], e)
if len(batch) == 1:
await batch[0].spoof_response(e)
return
elif "out of gas" in str(e):
# TODO Remember which contracts/calls are gas guzzlers
main_logger.debug('out of gas. cut in half, trying again')
elif any(err in str(e).lower() for err in ["connection reset by peer","request entity too large","server disconnected","execution aborted (timeout = 5s)"]):
main_logger.debug('dank too loud, trying again')
new_step = round(len(inputs) * 0.99) if len(inputs) >= 100 else len(inputs) - 1
num_calls = len(batch)
new_step = round(num_calls * 0.99) if num_calls >= 100 else num_calls - 1
# We need this check because one of the other multicalls in a batch might have already reduced `self.batcher.step`
if new_step < self.batcher.step:
old_step = self.batcher.step
self.batcher.step = new_step
main_logger.warning(f'Multicall batch size reduced from {old_step} to {new_step}. The failed batch had {len(inputs)} calls.')
main_logger.warning(f'Multicall batch size reduced from {old_step} to {new_step}. The failed batch had {len(batch)} calls.')
else:
main_logger.warning(f"unexpected exception: {type(e)} {str(e)}")

batches = [[cid, input] for cid, input in zip(cids, inputs)]
halfpoint = len(batches) // 2
halfpoint = len(batch) // 2
await gather([
self.process_batch(batches[:halfpoint], block, str(bid)+"_0"),
self.process_batch(batches[halfpoint:], block, str(bid)+"_1"),
self.process_batch(batch[:halfpoint], block, str(bid)+"_0"),
self.process_batch(batch[halfpoint:], block, str(bid)+"_1"),
])
demo_logger.info(f'tryBlockAndAggregate {mid} for batch {bid} complete')

@lru_cache(maxsize=None)
def _multicall_for_block(self, block: str) -> multicall.Call:
return multicall.Call(
@@ -228,34 +214,6 @@ def _multicall_for_block(self, block: str) -> multicall.Call:
gas_limit=GAS_LIMIT,
state_override_code=OVERRIDE_CODE,
)

async def spoof_response(self, cid: int, block: str, params: List, data: Optional[Union[str,Exception]]) -> None:
""" Creates a spoof rpc response for call # `cid` using data returned from multicall2. """
if (
# If multicall failed, try single call to get detailed error.
data is None
# Or, if we got a known "bad" response from the multicall, try single call.
# Could be that target contract does not support multicall2.
or (isinstance(data, bytes) and HexBytes(data).hex() in BAD_HEXES)
):
target, calldata = params
data = await run_in_subprocess(reattempt_call_and_return_exception, target, calldata, block, self.sync_w3)
# We were able to get a usable response from single call.
# Add contract to DO_NOT_BATCH list
if not isinstance(data, Exception):
self.DO_NOT_BATCH.add(target)
if isinstance(data, Exception):
spoof = {"error": {'code': -32000, 'message': _err_msg(data)}}
else:
spoof = {"result": HexBytes(data).hex()}
spoof.update({"id": cid, "jsonrpc": "dank_mids"})
main_logger.debug(f"spoof: {spoof}")
self.completed_calls[block][cid] = spoof

# Pop the call from in_process_calls
# TODO figure out why cids can be missing from in_process_calls when they haven't been popped yet
if cid in self.in_process_calls[block]:
self.in_process_calls[block].pop(cid)

async def _setup(self) -> None:
if self._initializing:
@@ -278,3 +236,70 @@
new = getattr(self, attr) + 1
setattr(self, attr, new)
return new


class ResponseNotReady(Exception):
pass

class BatchedCall:
def __init__(self, controller: DankMiddlewareController, params: Any) -> None:
""" Adds a call to the DankMiddlewareContoller's `pending_calls`. """
self.cid = controller.next_cid
self.block: str = params[1]
self.controller = controller
self.target = params[0]['to']
self.calldata = HexBytes(params[0]['data'])
self.controller.pending_calls.append(self)
self.started = False
self._response: Optional[RPCResponse] = None
demo_logger.info(f'added to queue (cid: {self.cid})')

@property
def is_complete(self) -> bool:
return self._response is not None

@property
def response(self) -> RPCResponse:
if self._response is None:
raise ResponseNotReady(self)
return self._response

def __eq__(self, __o: object) -> bool:
if not isinstance(__o, BatchedCall):
return False
return self.cid == __o.cid

def __hash__(self) -> int:
return self.cid

def __await__(self) -> Generator[Any, None, RPCResponse]:
return self.wait_for_response().__await__()

async def wait_for_response(self) -> RPCResponse:
if not self.controller.is_running:
await self.controller.taskmaster_loop()
while not self.is_complete:
await asyncio.sleep(0)
return self.response

async def spoof_response(self, data: Optional[Union[str,Exception]]) -> None:
""" Creates a spoof rpc response for this BatchedCall instance using data provided by the controller. """
if (
# If multicall failed, try single call to get detailed error.
data is None
# Or, if we got a known "bad" response from the multicall, try single call.
# Could be that target contract does not support multicall2.
or (isinstance(data, bytes) and HexBytes(data).hex() in BAD_HEXES)
):
data = await run_in_subprocess(reattempt_call_and_return_exception, self.target, self.calldata, self.block, self.controller.sync_w3)
# We were able to get a usable response from single call.
# Add contract to DO_NOT_BATCH list
if not isinstance(data, Exception):
self.controller.DO_NOT_BATCH.add(self.target)
if isinstance(data, Exception):
spoof = {"error": {'code': -32000, 'message': _err_msg(data)}}
else:
spoof = {"result": HexBytes(data).hex()}
spoof.update({"id": self.cid, "jsonrpc": "dank_mids"})
main_logger.debug(f"spoof: {spoof}")
self._response = spoof
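
The core of this change is that each queued request is now an awaitable object: the middleware simply does `return await call`, and the first awaiter kicks off `taskmaster_loop`, which drains the queue and fills in responses. Below is a minimal, self-contained sketch of that awaitable-queue pattern. `ToyController`, `ToyCall`, and the canned responses are hypothetical stand-ins for illustration only, not dank_mids' actual API.

import asyncio
from typing import Any, Dict, Generator, List, Optional


class ToyController:
    """Hypothetical stand-in for DankMiddlewareController: queues calls, then batches them."""

    def __init__(self) -> None:
        self.pending_calls: List["ToyCall"] = []
        self.is_running = False

    async def taskmaster_loop(self) -> None:
        # Drain the queue until every queued call has a response.
        self.is_running = True
        while self.pending_calls:
            await asyncio.sleep(0)
            batch, self.pending_calls = self.pending_calls, []
            for call in batch:
                # In dank_mids this would be a batched call to the node; here we fake it.
                call.spoof_response({"id": call.cid, "result": f"0xfake{call.cid}"})
        self.is_running = False


class ToyCall:
    """Hypothetical stand-in for BatchedCall: an awaitable that resolves once the controller responds."""

    def __init__(self, controller: ToyController, cid: int) -> None:
        self.controller = controller
        self.cid = cid
        self._response: Optional[Dict[str, Any]] = None
        controller.pending_calls.append(self)

    def spoof_response(self, response: Dict[str, Any]) -> None:
        self._response = response

    def __await__(self) -> Generator[Any, None, Dict[str, Any]]:
        return self.wait_for_response().__await__()

    async def wait_for_response(self) -> Dict[str, Any]:
        # The first awaiter drives the batching loop; other awaiters spin until their response is filled.
        if not self.controller.is_running:
            await self.controller.taskmaster_loop()
        while self._response is None:
            await asyncio.sleep(0)
        return self._response


async def main() -> None:
    controller = ToyController()
    calls = [ToyCall(controller, cid) for cid in range(3)]
    print(await asyncio.gather(*calls))


asyncio.run(main())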
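
The controller also keeps offloading batch execution to a second event loop running in a dedicated thread (the helper renamed here from `start_caller_event_loop` to `start_worker_event_loop`, driven via `asyncio.run_coroutine_threadsafe`). A rough sketch of that thread-plus-loop pattern follows; the helper's body is elided in this diff, so the `set_event_loop`/`run_forever` body and the `heavy_io_work` coroutine below are assumptions for illustration.

import asyncio
import threading


def start_worker_event_loop(loop: asyncio.AbstractEventLoop) -> None:
    # Assumed body: bind the loop to this thread and run it until stopped.
    asyncio.set_event_loop(loop)
    loop.run_forever()


async def heavy_io_work(n: int) -> int:
    # Hypothetical stand-in for a batched node request.
    await asyncio.sleep(0.1)
    return n * 2


# Spin up the worker loop in a background thread, as __init__ does above.
worker_loop = asyncio.new_event_loop()
threading.Thread(target=start_worker_event_loop, args=(worker_loop,), daemon=True).start()

# Submit a coroutine to the worker loop from another thread and block for its result,
# mirroring the asyncio.run_coroutine_threadsafe(...).result() call in execute_multicall.
future = asyncio.run_coroutine_threadsafe(heavy_io_work(21), worker_loop)
print(future.result())  # 42

worker_loop.call_soon_threadsafe(worker_loop.stop)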
