From b78b7a39448010ec0c2dd89abf92f2d2f8140a7c Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 Date: Mon, 16 Dec 2024 14:38:56 +0800 Subject: [PATCH] add thread for zodb IO --- kag/common/checkpointer/bin_checkpointer.py | 52 ++++++++++++--------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/kag/common/checkpointer/bin_checkpointer.py b/kag/common/checkpointer/bin_checkpointer.py index 9df86fc9..df1d7c96 100644 --- a/kag/common/checkpointer/bin_checkpointer.py +++ b/kag/common/checkpointer/bin_checkpointer.py @@ -12,6 +12,7 @@ import shelve import logging import transaction +import threading from ZODB import DB from ZODB.FileStorage import FileStorage from kag.common.checkpointer.base import CheckPointer @@ -101,6 +102,7 @@ def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1): rank (int): The rank of the current process (default is 0). world_size (int): The total number of processes in the distributed environment (default is 1). """ + self._lock = threading.Lock() super().__init__(ckpt_dir, rank, world_size) def open(self): @@ -110,10 +112,10 @@ def open(self): Returns: dict: The root object of the ZODB database, which is a dictionary-like object. """ - - storage = FileStorage(self._ckpt_file_path) - db = DB(storage) - return db + with self._lock: + storage = FileStorage(self._ckpt_file_path) + db = DB(storage) + return db def read_from_ckpt(self, key): """ @@ -125,9 +127,10 @@ def read_from_ckpt(self, key): Returns: Any: The value associated with the key in the checkpoint. """ - with self._ckpt.transaction() as conn: - obj = conn.root().get(key, None) - return obj + with self._lock: + with self._ckpt.transaction() as conn: + obj = conn.root().get(key, None) + return obj def write_to_ckpt(self, key, value): """ @@ -137,23 +140,24 @@ def write_to_ckpt(self, key, value): key (str): The key to store the value in the checkpoint. value (Any): The value to be stored in the checkpoint. """ - - try: - with self._ckpt.transaction() as conn: - conn.root()[key] = value - except Exception as e: - logger.warn(f"failed to write checkpoint {key} to db, info: {e}") + with self._lock: + try: + with self._ckpt.transaction() as conn: + conn.root()[key] = value + except Exception as e: + logger.warn(f"failed to write checkpoint {key} to db, info: {e}") def close(self): """ Closes the ZODB database connection. """ - try: - transaction.commit() - except: - transaction.abort() - if self._ckpt is not None: - self._ckpt.close() + with self._lock: + try: + transaction.commit() + except: + transaction.abort() + if self._ckpt is not None: + self._ckpt.close() def exists(self, key): """ @@ -165,9 +169,11 @@ def exists(self, key): Returns: bool: True if the key exists in the checkpoint, False otherwise. """ - with self._ckpt.transaction() as conn: - return key in conn.root() + with self._lock: + with self._ckpt.transaction() as conn: + return key in conn.root() def size(self): - with self._ckpt.transaction() as conn: - return len(conn.root()) + with self._lock: + with self._ckpt.transaction() as conn: + return len(conn.root())