Skip to content

Commit

Permalink
add thread for zodb IO
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzhongshu123 committed Dec 16, 2024
1 parent 263b073 commit b78b7a3
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions kag/common/checkpointer/bin_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import shelve
import logging
import transaction
import threading
from ZODB import DB
from ZODB.FileStorage import FileStorage
from kag.common.checkpointer.base import CheckPointer
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
rank (int): The rank of the current process (default is 0).
world_size (int): The total number of processes in the distributed environment (default is 1).
"""
self._lock = threading.Lock()
super().__init__(ckpt_dir, rank, world_size)

def open(self):
Expand All @@ -110,10 +112,10 @@ def open(self):
Returns:
dict: The root object of the ZODB database, which is a dictionary-like object.
"""

storage = FileStorage(self._ckpt_file_path)
db = DB(storage)
return db
with self._lock:
storage = FileStorage(self._ckpt_file_path)
db = DB(storage)
return db

def read_from_ckpt(self, key):
"""
Expand All @@ -125,9 +127,10 @@ def read_from_ckpt(self, key):
Returns:
Any: The value associated with the key in the checkpoint.
"""
with self._ckpt.transaction() as conn:
obj = conn.root().get(key, None)
return obj
with self._lock:
with self._ckpt.transaction() as conn:
obj = conn.root().get(key, None)
return obj

def write_to_ckpt(self, key, value):
"""
Expand All @@ -137,23 +140,24 @@ def write_to_ckpt(self, key, value):
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""

try:
with self._ckpt.transaction() as conn:
conn.root()[key] = value
except Exception as e:
logger.warn(f"failed to write checkpoint {key} to db, info: {e}")
with self._lock:
try:
with self._ckpt.transaction() as conn:
conn.root()[key] = value
except Exception as e:
logger.warn(f"failed to write checkpoint {key} to db, info: {e}")

def close(self):
"""
Closes the ZODB database connection.
"""
try:
transaction.commit()
except:
transaction.abort()
if self._ckpt is not None:
self._ckpt.close()
with self._lock:
try:
transaction.commit()
except:
transaction.abort()
if self._ckpt is not None:
self._ckpt.close()

def exists(self, key):
"""
Expand All @@ -165,9 +169,11 @@ def exists(self, key):
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
with self._ckpt.transaction() as conn:
return key in conn.root()
with self._lock:
with self._ckpt.transaction() as conn:
return key in conn.root()

def size(self):
with self._ckpt.transaction() as conn:
return len(conn.root())
with self._lock:
with self._ckpt.transaction() as conn:
return len(conn.root())

0 comments on commit b78b7a3

Please sign in to comment.