Skip to content

Commit

Permalink
fixes for None cachefile
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Jan 10, 2025
1 parent f3c9093 commit be6aff3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
10 changes: 5 additions & 5 deletions boa/util/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from eth.db.backends.base import BaseDB

#from vyper.utils import timeit
# from vyper.utils import timeit

# poor man's constant
_ONE_MONTH = 30 * 24 * 3600
Expand Down Expand Up @@ -114,7 +114,7 @@ def _flush_condition(self):
return True
return False

#@timeit("FLUSH")
# @timeit("FLUSH")
def _flush(self, nolock=False):
# set nolock=True if the caller has already acquired a lock.
if len(self._expiry_updates) == 0:
Expand Down Expand Up @@ -178,7 +178,7 @@ def get_expiry_ts(self) -> float:
current_time = get_current_time()
return current_time + self.ttl

#@timeit("CACHE HIT")
# @timeit("CACHE HIT")
def __getitem__(self, key: bytes) -> bytes:
query_string = """
SELECT value, expires_at FROM kv_store
Expand All @@ -201,9 +201,9 @@ def __getitem__(self, key: bytes) -> bytes:

return val

#@timeit("CACHE MISS")
# @timeit("CACHE MISS")
def __setitem__(self, key: bytes, value: bytes) -> None:
#with timeit("CACHE MISS"):
# with timeit("CACHE MISS"):
with self.acquire_write_lock():
query_string = """
INSERT INTO kv_store(key, value, expires_at) VALUES (?,?,?)
Expand Down
20 changes: 11 additions & 9 deletions boa/vm/fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle
import sys
from pathlib import Path
from typing import Any, Type
from typing import Any, Optional, Type

import rlp
from eth.db.account import AccountDB, keccak
Expand Down Expand Up @@ -36,7 +36,7 @@ class CachingRPC(RPC):
_loaded: dict[tuple[str, int, str], "CachingRPC"] = {}
_pid: int = os.getpid() # so we can detect if our fds are bad

def __new__(cls, rpc, chain_id, debug, cache_dir=None):
def __new__(cls, rpc, chain_id, debug, cache_dir=DEFAULT_CACHE_DIR):
if isinstance(rpc, cls):
if rpc._chain_id == chain_id:
return rpc
Expand All @@ -58,16 +58,17 @@ def __new__(cls, rpc, chain_id, debug, cache_dir=None):
return ret

def __init__(
self, rpc: RPC, chain_id: int, debug: bool = False, cache_dir: str = None
self,
rpc: RPC,
chain_id: int,
debug: bool = False,
cache_dir: Optional[str] = DEFAULT_CACHE_DIR,
):
self._rpc = rpc
self._debug = debug

self._chain_id = chain_id # TODO: check if this is needed

self._cache_file = None
if cache_dir is not None:
self._cache_file = self._cache_filepath(cache_dir, chain_id)
self._cache_dir = cache_dir

self._init_db()

Expand All @@ -76,8 +77,9 @@ def _cache_filepath(cls, cache_dir, chain_id):
return Path(cache_dir) / f"chainid_{hex(chain_id)}-sqlite.db"

def _init_db(self):
if self._cache_file is not None:
cache_file = os.path.expanduser(self._cache_file)
if self._cache_dir is not None:
cache_file = self._cache_filepath(self._cache_dir, self._chain_id)
cache_file = os.path.expanduser(cache_file)
sqlitedb = SqliteCache.create(cache_file)
# use CacheDB as an additional layer over disk
self._db = CacheDB(sqlitedb, cache_size=1024 * 1024) # type: ignore
Expand Down

0 comments on commit be6aff3

Please sign in to comment.