fix contention and refactor database initialization
previously there was a bug where keys might overlap between chain ids
(unlikely, but possible). this commit refactors so that each chainid
gets its own db.
charles-cooper committed Jan 10, 2025
1 parent e3a15a5 commit 1b5bb70
Showing 3 changed files with 157 additions and 73 deletions.
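The fix amounts to giving each chain ID its own sqlite file instead of sharing one database across chains. A minimal sketch of the path scheme, mirroring the _cache_filepath helper added to boa/vm/fork.py below (the chain ID and cache directory here are purely illustrative):

# example sketch, not part of the diff
from pathlib import Path

def cache_filepath(cache_dir: str, chain_id: int) -> Path:
    # one sqlite file per chain id, so keys can no longer collide
    # across chains that previously shared a single database
    return Path(cache_dir) / f"chainid_{hex(chain_id)}-sqlite.db"

# mainnet (chain id 1) under the new default cache dir:
# cache_filepath("~/.cache/titanoboa/fork/", 1)
#   -> ~/.cache/titanoboa/fork/chainid_0x1-sqlite.db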
132 changes: 94 additions & 38 deletions boa/util/sqlitedb.py
@@ -1,3 +1,4 @@
import contextlib
import sqlite3
import time
from pathlib import Path
@@ -8,8 +9,8 @@
_ONE_MONTH = 30 * 24 * 3600


def get_current_time() -> int:
return int(time.time())
def get_current_time() -> float:
return time.time()


class SqliteCache(BaseDB):
@@ -20,82 +21,137 @@ class SqliteCache(BaseDB):

_GLOBAL = None

_CREATE_CMDS = """
_CREATE_CMDS = [
"""
pragma journal_mode=wal
""",
"""
CREATE TABLE IF NOT EXISTS kv_store (
key TEXT PRIMARY KEY, value BLOB, expires_at NUM
);
CREATE INDEX IF NOT EXISTS expires_at_index ON kv_store(expires_at)
""".split(
";"
key TEXT PRIMARY KEY, value BLOB, expires_at float
)
""",
"""
CREATE INDEX IF NOT EXISTS expires_at_index ON kv_store(expires_at)
""",
]

# flush at least once per second
_MAX_FLUSH_TIME = 1.0

def __init__(self, db_path: Path | str, ttl: int = _ONE_MONTH) -> None:
def __init__(self, db_path: Path | str, ttl: float = _ONE_MONTH) -> None:
if db_path != ":memory:": # sqlite magic path
db_path = Path(db_path)
db_path.parent.mkdir(parents=True, exist_ok=True)

# once 3.12 is min version, use autocommit=True
self.db: sqlite3.Connection = sqlite3.connect(db_path)
for cmd in self.__class__._CREATE_CMDS:
self.db.execute(cmd)
self.db: sqlite3.Connection = sqlite3.connect(
db_path, timeout=0.0, isolation_level=None
)
with self.acquire_write_lock():
for cmd in self.__class__._CREATE_CMDS:
self.db.execute(cmd)

# ttl of cache entries in seconds
self.ttl: int = ttl
# ttl = 100
self.ttl: float = float(ttl)

self.gc()

self._last_flush = get_current_time()
self._expiry_updates: list[tuple[float, bytes]] = []

def gc(self):
current_time = get_current_time()
self.db.execute("DELETE FROM kv_store WHERE expires_at < ?", (current_time,))
self.db.commit()
with self.acquire_write_lock():
current_time = get_current_time()
self.db.execute(
"DELETE FROM kv_store WHERE expires_at < ?", (current_time,)
)

def __del__(self):
self._flush()

def _flush_condition(self):
if len(self._expiry_updates) == 0:
return False

next_flush = self._last_flush + self._MAX_FLUSH_TIME
return len(self._expiry_updates) > 1000 or get_current_time() > next_flush

def _flush(self):
with self.acquire_write_lock():
query_string = """
UPDATE kv_store
SET expires_at=?
WHERE key=?
"""
self.db.executemany(query_string, self._expiry_updates)
self._expiry_updates = []

@contextlib.contextmanager
def acquire_write_lock(self):
while True:
try:
self.db.execute("BEGIN IMMEDIATE")
break
except sqlite3.OperationalError:
# sleep 100 micros
time.sleep(1e-4)
continue
try:
yield
self.db.commit()
except Exception:
self.db.rollback()

@classmethod
# Creates db as a class variable to avoid level db lock error
# create the singleton db object
# Creates db as a singleton class variable
def create(cls, *args, **kwargs):
if cls._GLOBAL is None:
cls._GLOBAL = cls(*args, **kwargs)
return cls._GLOBAL

def get_expiry_ts(self):
def get_expiry_ts(self) -> float:
current_time = get_current_time()
return current_time + self.ttl

def __getitem__(self, key: bytes) -> bytes:
query_string = """
UPDATE kv_store
SET expires_at=?
SELECT value, expires_at FROM kv_store
WHERE key=?
RETURNING value
"""
expiry_ts = self.get_expiry_ts()
res = self.db.execute(query_string, (expiry_ts, key)).fetchone()
res = self.db.execute(query_string, (key,)).fetchone()
if res is None:
raise KeyError(key)
(val,) = res
self.db.commit()

val, expires_at = res

# to reduce contention, instead of updating the expiry every
# time, batch the expiry updates.
if expires_at - get_current_time() > self.ttl / 100:
new_expiry_ts = self.get_expiry_ts()
self._expiry_updates.append((new_expiry_ts, key))
if self._flush_condition():
self._flush()

return val

def __setitem__(self, key: bytes, value: bytes) -> None:
query_string = """
INSERT INTO kv_store(key, value, expires_at) VALUES (?,?,?)
ON CONFLICT DO UPDATE
SET key=excluded.key,
value=excluded.value,
ON CONFLICT
SET value=excluded.value,
expires_at=excluded.expires_at
"""
expiry_ts = self.get_expiry_ts()
self.db.execute(query_string, (key, value, expiry_ts))
self.db.commit()
with self.acquire_write_lock():
expiry_ts = self.get_expiry_ts()
self.db.execute(query_string, (key, value, expiry_ts))

def _exists(self, key: bytes) -> bool:
res = self.db.execute(
"SELECT count(*) FROM kv_store WHERE key=?", (key,)
).fetchone()
query_string = "SELECT count(*) FROM kv_store WHERE key=?"
(res,) = self.db.execute(query_string, (key,)).fetchone()
return bool(res)

def __delitem__(self, key: bytes) -> None:
res = self.db.execute("DELETE FROM kv_store WHERE key=?", (key,))
with self.acquire_write_lock():
res = self.db.execute("DELETE FROM kv_store WHERE key=?", (key,))
if res.rowcount == 0:
raise KeyError(key)
self.db.commit()
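Taken together, the sqlitedb.py changes mean a read no longer opens a write transaction at all: it runs a plain SELECT, queues an expiry refresh, and the queue is flushed in a single executemany once more than 1000 updates accumulate or _MAX_FLUSH_TIME (1.0 s) has elapsed, plus once more when the object is garbage-collected. A hypothetical usage sketch (path, ttl and keys are illustrative only):

# example sketch, not part of the diff
from boa.util.sqlitedb import SqliteCache

# process-wide singleton; later create() calls return the same object
cache = SqliteCache.create("/tmp/example-sqlite.db", ttl=3600)

cache[b"some-key"] = b"some-value"  # write: BEGIN IMMEDIATE, committed immediately
value = cache[b"some-key"]          # read: plain SELECT; expiry refresh may be queued
cache._flush()                      # queued refreshes written in one executemany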
92 changes: 59 additions & 33 deletions boa/vm/fork.py
@@ -1,4 +1,6 @@
import os
import pickle
from pathlib import Path
from typing import Any, Type

import rlp
@@ -12,14 +14,14 @@
from eth_utils import int_to_big_endian, to_canonical_address, to_checksum_address
from requests import HTTPError

from boa.rpc import RPC, RPCError, fixup_dict, json, to_bytes, to_hex, to_int
from boa.rpc import RPC, RPCError, fixup_dict, to_bytes, to_hex, to_int
from boa.util.lrudict import lrudict
from boa.util.sqlitedb import SqliteCache

TIMEOUT = 60 # default timeout for http requests in seconds


DEFAULT_CACHE_DIR = "~/.cache/titanoboa/fork-sqlite.db"
DEFAULT_CACHE_DIR = "~/.cache/titanoboa/fork/"
_PREDEFINED_BLOCKS = {"safe", "latest", "finalized", "pending", "earliest"}


@@ -28,16 +30,49 @@


class CachingRPC(RPC):
def __init__(self, rpc: RPC, cache_file: str = DEFAULT_CACHE_DIR):
# _loaded is a cache for the constructor.
# reduces fork time after the first fork.
_loaded: dict[tuple[str, int, str], "CachingRPC"] = {}
_pid: int = os.getpid() # so we can detect if our fds are bad

def __new__(cls, rpc, chain_id, cache_dir=None):
if isinstance(rpc, cls):
if rpc._chain_id == chain_id:
return rpc
else:
# unwrap
rpc = rpc._rpc

if os.getpid() != cls._pid:
# we are in a fork. reload everything so that fds are not corrupted
cls._loaded = {}
cls._pid = os.getpid()

if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

if (rpc.identifier, chain_id, cache_dir) in cls._loaded:
return cls._loaded[(rpc.identifier, chain_id, cache_dir)]

ret = super().__new__(cls)
ret.__init__(rpc, chain_id, cache_dir)
cls._loaded[(rpc.identifier, chain_id, cache_dir)] = ret
return ret

def __init__(self, rpc: RPC, chain_id: int, cache_dir: str = None):
self._rpc = rpc

self._cache_file = cache_file
self._chain_id = chain_id # TODO: check if this is needed

if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

self._cache_file = self._cache_filepath(cache_dir, chain_id)
self._init_db()

# _loaded is a cache for the constructor.
# reduces fork time after the first fork.
_loaded: dict[tuple[str, str], "CachingRPC"] = {}
_pid: int = os.getpid() # so we can detect if our fds are bad
@classmethod
def _cache_filepath(cls, cache_dir, chain_id):
return Path(cache_dir) / f"chainid_{hex(chain_id)}-sqlite.db"

def _init_db(self):
if self._cache_file is not None:
@@ -58,35 +93,19 @@ def identifier(self) -> str:
def name(self):
return self._rpc.name

def __new__(cls, rpc, cache_file=DEFAULT_CACHE_DIR):
if isinstance(rpc, cls):
return rpc

if os.getpid() != cls._pid:
# we are in a fork. reload everything so that fds are not corrupted
cls._loaded = {}
cls._pid = os.getpid()

if (rpc.identifier, cache_file) in cls._loaded:
return cls._loaded[(rpc.identifier, cache_file)]

ret = super().__new__(cls)
ret.__init__(rpc, cache_file)
cls._loaded[(rpc.identifier, cache_file)] = ret
return ret

# a stupid key for the kv store
def _mk_key(self, method: str, params: Any) -> Any:
return json.dumps({"method": method, "params": params}).encode("utf-8")
return pickle.dumps((method, params))

def fetch(self, method, params):
# cannot dispatch into fetch_multi, doesn't work for debug_traceCall.
key = self._mk_key(method, params)
if key in self._db:
return json.loads(self._db[key])
ret = pickle.loads(self._db[key])
return ret

result = self._rpc.fetch(method, params)
self._db[key] = json.dumps(result).encode("utf-8")
self._db[key] = pickle.dumps(result)
return result

def fetch_uncached(self, method, params):
@@ -100,7 +119,7 @@ def fetch_multi(self, payload):
for item_ix, (method, params) in enumerate(payload):
key = self._mk_key(method, params)
try:
ret[item_ix] = json.loads(self._db[key])
ret[item_ix] = pickle.loads(self._db[key])
except KeyError:
keys.append((key, item_ix))
batch.append((method, params))
@@ -111,7 +130,7 @@ def fetch_multi(self, payload):
for result_ix, rpc_result in enumerate(self._rpc.fetch_multi(batch)):
key, item_ix = keys[result_ix]
ret[item_ix] = rpc_result
self._db[key] = json.dumps(rpc_result).encode("utf-8")
self._db[key] = pickle.dumps(rpc_result)

return [ret[i] for i in range(len(ret))]

@@ -125,12 +144,17 @@ def class_from_rpc(
) -> Type["AccountDBFork"]:
class _ConfiguredAccountDB(AccountDBFork):
def __init__(self, *args, **kwargs2):
caching_rpc = CachingRPC(rpc, **kwargs)
super().__init__(caching_rpc, block_identifier, *args, **kwargs2)
chain_id = int(rpc.fetch_uncached("eth_chainId", []), 16)
caching_rpc = CachingRPC(rpc, chain_id, **kwargs)
super().__init__(
caching_rpc, chain_id, block_identifier, *args, **kwargs2
)

return _ConfiguredAccountDB

def __init__(self, rpc: CachingRPC, block_identifier: str, *args, **kwargs) -> None:
def __init__(
self, rpc: CachingRPC, chain_id: int, block_identifier: str, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)

self._dontfetch = JournalDB(MemoryDB())
@@ -140,6 +164,8 @@ def __init__(self, rpc: CachingRPC, block_identifier: str, *args, **kwargs) -> N
if block_identifier not in _PREDEFINED_BLOCKS:
block_identifier = to_hex(block_identifier)

self._chain_id = chain_id

self._block_info = self._rpc.fetch_uncached(
"eth_getBlockByNumber", [block_identifier, False]
)
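Another change running through fork.py above: the cache now serializes both its keys and its stored results with pickle instead of JSON (_mk_key, fetch and fetch_multi). A minimal illustration of the round trip (method and params are arbitrary examples):

# example sketch, not part of the diff
import pickle

def mk_key(method: str, params) -> bytes:
    # same shape as CachingRPC._mk_key above: the (method, params)
    # pair is pickled directly rather than dumped to JSON text
    return pickle.dumps((method, params))

key = mk_key("eth_getBlockByNumber", ["0x112a880", False])
result = {"number": "0x112a880", "timestamp": "0x65f1c2f0"}

stored = pickle.dumps(result)          # what gets written to the sqlite cache
assert pickle.loads(stored) == result  # what a cache hit hands back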
6 changes: 4 additions & 2 deletions boa/vm/py_evm.py
@@ -407,14 +407,16 @@ def enable_fast_mode(self, flag: bool = True):
else:
unpatch_pyevm_state_object(self.vm.state)

def fork_rpc(self, rpc: RPC, block_identifier: str, force: bool = False, **kwargs):
def fork_rpc(self, rpc: RPC, block_identifier: str, **kwargs):
account_db_class = AccountDBFork.class_from_rpc(rpc, block_identifier, **kwargs)
self._init_vm(account_db_class)

block_info = self.vm.state._account_db._block_info
chain_id = self.vm.state._account_db._chain_id

self.patch.timestamp = int(block_info["timestamp"], 16)
self.patch.block_number = int(block_info["number"], 16)
self.patch.chain_id = int(rpc.fetch("eth_chainId", []), 16)
self.patch.chain_id = chain_id

# placeholder not to fetch all prev hashes
# (NOTE: we should document this)
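With the py_evm.py change, fork_rpc takes the chain id from the forked account DB rather than issuing a second eth_chainId request; the id is fetched once, uncached, in class_from_rpc and parsed from the hex response. A tiny illustration of the parsing (response strings are illustrative):

# example sketch, not part of the diff
chain_id = int("0x1", 16)  # what int(rpc.fetch_uncached("eth_chainId", []), 16) gives for mainnet
assert chain_id == 1

block_info = {"number": "0x112a880", "timestamp": "0x65f1c2f0"}
block_number = int(block_info["number"], 16)  # 18000000, used for patch.block_number
timestamp = int(block_info["timestamp"], 16)  # unix timestamp, used for patch.timestamp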
