diff --git a/kag/common/checkpointer/bin_checkpointer.py b/kag/common/checkpointer/bin_checkpointer.py index df1d7c96..980917fa 100644 --- a/kag/common/checkpointer/bin_checkpointer.py +++ b/kag/common/checkpointer/bin_checkpointer.py @@ -13,6 +13,8 @@ import logging import transaction import threading + +import BTrees.OOBTree from ZODB import DB from ZODB.FileStorage import FileStorage from kag.common.checkpointer.base import CheckPointer @@ -115,6 +117,8 @@ def open(self): with self._lock: storage = FileStorage(self._ckpt_file_path) db = DB(storage) + with db.transaction() as conn: + conn.root.data = BTrees.OOBTree.BTree() return db def read_from_ckpt(self, key): @@ -129,7 +133,7 @@ def read_from_ckpt(self, key): """ with self._lock: with self._ckpt.transaction() as conn: - obj = conn.root().get(key, None) + obj = conn.root.data.get(key, None) return obj def write_to_ckpt(self, key, value): @@ -143,7 +147,7 @@ def write_to_ckpt(self, key, value): with self._lock: try: with self._ckpt.transaction() as conn: - conn.root()[key] = value + conn.root.data[key] = value except Exception as e: logger.warn(f"failed to write checkpoint {key} to db, info: {e}") @@ -171,9 +175,9 @@ def exists(self, key): """ with self._lock: with self._ckpt.transaction() as conn: - return key in conn.root() + return key in conn.root.data def size(self): with self._lock: with self._ckpt.transaction() as conn: - return len(conn.root()) + return len(conn.root.data)