Skip to content

Commit

Permalink
fix zodb based checkpointer
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzhongshu123 committed Dec 16, 2024
1 parent 433bc50 commit 263b073
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
14 changes: 12 additions & 2 deletions kag/common/checkpointer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
self._ckpt_dir, CheckPointer.ckpt_file_name.format(rank, world_size)
)
self._ckpt = self.open()
if len(self._ckpt) > 0:
if self.size() > 0:
print(
f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}, with {len(self._ckpt)} records.{reset}"
f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}, with {self.size()} records.{reset}"
)

def open(self):
Expand Down Expand Up @@ -100,3 +100,13 @@ def exists(self, key):
bool: True if the key exists in the checkpoint, False otherwise.
"""
raise NotImplementedError("close not implemented yet.")

def size(self):
"""
Return the number of records in the checkpoint file.
Returns:
int: the number of records in the checkpoint file.
"""

raise NotImplementedError("size not implemented yet.")
31 changes: 19 additions & 12 deletions kag/common/checkpointer/bin_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def close(self):
self._ckpt.sync()
self._ckpt.close()

def size(self):
return len(self._ckpt)


@CheckPointer.register("zodb")
class ZODBCheckPointer(CheckPointer):
Expand Down Expand Up @@ -109,9 +112,8 @@ def open(self):
"""

storage = FileStorage(self._ckpt_file_path)
self._db = DB(storage)
self._connection = self._db.open()
return self._connection.root()
db = DB(storage)
return db

def read_from_ckpt(self, key):
"""
Expand All @@ -123,7 +125,9 @@ def read_from_ckpt(self, key):
Returns:
Any: The value associated with the key in the checkpoint.
"""
return self._ckpt.get(key, None)
with self._ckpt.transaction() as conn:
obj = conn.root().get(key, None)
return obj

def write_to_ckpt(self, key, value):
"""
Expand All @@ -133,11 +137,11 @@ def write_to_ckpt(self, key, value):
key (str): The key to store the value in the checkpoint.
value (Any): The value to be stored in the checkpoint.
"""
self._ckpt[key] = value

try:
transaction.commit()
with self._ckpt.transaction() as conn:
conn.root()[key] = value
except Exception as e:
transaction.abort()
logger.warn(f"failed to write checkpoint {key} to db, info: {e}")

def close(self):
Expand All @@ -148,10 +152,8 @@ def close(self):
transaction.commit()
except:
transaction.abort()
if self._connection is not None:
self._connection.close()
if self._db is not None:
self._db.close()
if self._ckpt is not None:
self._ckpt.close()

def exists(self, key):
"""
Expand All @@ -163,4 +165,9 @@ def exists(self, key):
Returns:
bool: True if the key exists in the checkpoint, False otherwise.
"""
return key in self._ckpt
with self._ckpt.transaction() as conn:
return key in conn.root()

def size(self):
with self._ckpt.transaction() as conn:
return len(conn.root())
3 changes: 3 additions & 0 deletions kag/common/checkpointer/txt_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ def close(self):
"""
self._writer.flush()
self._writer.close()

def size(self):
return len(self._ckpt)

0 comments on commit 263b073

Please sign in to comment.