Skip to content

Commit

Permalink
Merge pull request #131 from zhuzhongshu123/0.6_dev
Browse files Browse the repository at this point in the history
fix(common): resolve multi-thread conflict in zodb IO
  • Loading branch information
zhuzhongshu123 authored Dec 16, 2024
2 parents 115515b + 5a9c73e commit 1b147d0
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions kag/common/checkpointer/bin_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import shelve
import logging
import transaction
import threading

import BTrees.OOBTree
from ZODB import DB
from ZODB.FileStorage import FileStorage
from kag.common.checkpointer.base import CheckPointer
Expand Down Expand Up @@ -101,6 +104,7 @@ def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
rank (int): The rank of the current process (default is 0).
world_size (int): The total number of processes in the distributed environment (default is 1).
"""
self._lock = threading.Lock()
super().__init__(ckpt_dir, rank, world_size)

def open(self):
Expand All @@ -110,10 +114,12 @@ def open(self):
Returns:
dict: The root object of the ZODB database, which is a dictionary-like object.
"""

storage = FileStorage(self._ckpt_file_path)
db = DB(storage)
return db
with self._lock:
storage = FileStorage(self._ckpt_file_path)
db = DB(storage)
with db.transaction() as conn:
conn.root.data = BTrees.OOBTree.BTree()
return db

def read_from_ckpt(self, key):
"""
Expand All @@ -125,9 +131,10 @@ def read_from_ckpt(self, key):
Returns:
Any: The value associated with the key in the checkpoint.
"""
with self._ckpt.transaction() as conn:
obj = conn.root().get(key, None)
return obj
with self._lock:
with self._ckpt.transaction() as conn:
obj = conn.root.data.get(key, None)
return obj

def write_to_ckpt(self, key, value):
"""
Expand All @@ -137,23 +144,24 @@ def write_to_ckpt(self, key, value):
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""

try:
with self._ckpt.transaction() as conn:
conn.root()[key] = value
except Exception as e:
logger.warn(f"failed to write checkpoint {key} to db, info: {e}")
with self._lock:
try:
with self._ckpt.transaction() as conn:
conn.root.data[key] = value
except Exception as e:
logger.warn(f"failed to write checkpoint {key} to db, info: {e}")

def close(self):
"""
Closes the ZODB database connection.
"""
try:
transaction.commit()
except:
transaction.abort()
if self._ckpt is not None:
self._ckpt.close()
with self._lock:
try:
transaction.commit()
except:
transaction.abort()
if self._ckpt is not None:
self._ckpt.close()

def exists(self, key):
"""
Expand All @@ -165,9 +173,11 @@ def exists(self, key):
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
with self._ckpt.transaction() as conn:
return key in conn.root()
with self._lock:
with self._ckpt.transaction() as conn:
return key in conn.root.data

def size(self):
with self._ckpt.transaction() as conn:
return len(conn.root())
with self._lock:
with self._ckpt.transaction() as conn:
return len(conn.root.data)

0 comments on commit 1b147d0

Please sign in to comment.