diff --git a/kag/common/checkpointer/base.py b/kag/common/checkpointer/base.py index f1963e0a..83e353bc 100644 --- a/kag/common/checkpointer/base.py +++ b/kag/common/checkpointer/base.py @@ -47,9 +47,9 @@ def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1): self._ckpt_dir, CheckPointer.ckpt_file_name.format(rank, world_size) ) self._ckpt = self.open() - if len(self._ckpt) > 0: + if self.size() > 0: print( - f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}, with {len(self._ckpt)} records.{reset}" + f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}, with {self.size()} records.{reset}" ) def open(self): @@ -100,3 +100,13 @@ def exists(self, key): bool: True if the key exists in the checkpoint, False otherwise. """ raise NotImplementedError("close not implemented yet.") + + def size(self): + """ + Return the number of records in the checkpoint file. + + Returns: + int: the number of records in the checkpoint file. + """ + + raise NotImplementedError("size not implemented yet.") diff --git a/kag/common/checkpointer/bin_checkpointer.py b/kag/common/checkpointer/bin_checkpointer.py index 142569c0..9df86fc9 100644 --- a/kag/common/checkpointer/bin_checkpointer.py +++ b/kag/common/checkpointer/bin_checkpointer.py @@ -80,6 +80,9 @@ def close(self): self._ckpt.sync() self._ckpt.close() + def size(self): + return len(self._ckpt) + @CheckPointer.register("zodb") class ZODBCheckPointer(CheckPointer): @@ -109,9 +112,8 @@ def open(self): """ storage = FileStorage(self._ckpt_file_path) - self._db = DB(storage) - self._connection = self._db.open() - return self._connection.root() + db = DB(storage) + return db def read_from_ckpt(self, key): """ @@ -123,7 +125,9 @@ def read_from_ckpt(self, key): Returns: Any: The value associated with the key in the checkpoint. """ - return self._ckpt.get(key, None) + with self._ckpt.transaction() as conn: + obj = conn.root().get(key, None) + return obj def write_to_ckpt(self, key, value): """ @@ -133,11 +137,11 @@ def write_to_ckpt(self, key, value): key (str): The key to store the value in the checkpoint. value (Any): The value to be stored in the checkpoint. """ - self._ckpt[key] = value + try: - transaction.commit() + with self._ckpt.transaction() as conn: + conn.root()[key] = value except Exception as e: - transaction.abort() logger.warn(f"failed to write checkpoint {key} to db, info: {e}") def close(self): @@ -148,10 +152,8 @@ def close(self): transaction.commit() except: transaction.abort() - if self._connection is not None: - self._connection.close() - if self._db is not None: - self._db.close() + if self._ckpt is not None: + self._ckpt.close() def exists(self, key): """ @@ -163,4 +165,9 @@ def exists(self, key): Returns: bool: True if the key exists in the checkpoint, False otherwise. """ - return key in self._ckpt + with self._ckpt.transaction() as conn: + return key in conn.root() + + def size(self): + with self._ckpt.transaction() as conn: + return len(conn.root()) diff --git a/kag/common/checkpointer/txt_checkpointer.py b/kag/common/checkpointer/txt_checkpointer.py index 357e055e..90263613 100644 --- a/kag/common/checkpointer/txt_checkpointer.py +++ b/kag/common/checkpointer/txt_checkpointer.py @@ -84,3 +84,6 @@ def close(self): """ self._writer.flush() self._writer.close() + + def size(self): + return len(self._ckpt)