Skip to content

Commit

Permalink
Merge pull request #127 from zhuzhongshu123/0.6_dev
Browse files Browse the repository at this point in the history
feat(builder): add a checkpointer class with ZODB as backend
  • Loading branch information
zhuzhongshu123 authored Dec 13, 2024
2 parents d878d82 + 433bc50 commit 03f5153
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 3 deletions.
4 changes: 3 additions & 1 deletion kag/common/checkpointer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
)
self._ckpt = self.open()
if len(self._ckpt) > 0:
print(f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}.{reset}")
print(
f"{bold}{red}Existing checkpoint found in {self._ckpt_dir}, with {len(self._ckpt)} records.{reset}"
)

def open(self):
"""
Expand Down
91 changes: 91 additions & 0 deletions kag/common/checkpointer/bin_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import shelve
import logging
import transaction
from ZODB import DB
from ZODB.FileStorage import FileStorage
from kag.common.checkpointer.base import CheckPointer

logger = logging.getLogger()


@CheckPointer.register("bin")
class BinCheckPointer(CheckPointer):
Expand Down Expand Up @@ -73,3 +79,88 @@ def close(self):
"""
self._ckpt.sync()
self._ckpt.close()


@CheckPointer.register("zodb")
class ZODBCheckPointer(CheckPointer):
    """
    A CheckPointer implementation that uses ZODB as the underlying storage.

    Checkpoint entries are stored as key/value pairs on the root object of a
    ZODB database backed by a ``FileStorage``. Every write is committed
    immediately so that already-written records are on disk after each
    successful ``write_to_ckpt`` call.
    """

    def __init__(self, ckpt_dir: str, rank: int = 0, world_size: int = 1):
        """
        Initializes the ZODBCheckPointer with the given checkpoint directory, rank, and world size.

        Args:
            ckpt_dir (str): The directory where checkpoint files are stored.
            rank (int): The rank of the current process (default is 0).
            world_size (int): The total number of processes in the distributed environment (default is 1).
        """
        super().__init__(ckpt_dir, rank, world_size)

    def open(self):
        """
        Opens the ZODB database and returns the root object for checkpoint storage.

        The database and connection handles are kept on the instance
        (``self._db`` / ``self._connection``) so that ``close`` can release them.

        Returns:
            The root object of the ZODB database, which is a dictionary-like object.
        """
        # NOTE(review): self._ckpt_file_path is presumably set by the base
        # class before open() is called -- confirm against CheckPointer.
        storage = FileStorage(self._ckpt_file_path)
        self._db = DB(storage)
        self._connection = self._db.open()
        return self._connection.root()

    def read_from_ckpt(self, key):
        """
        Reads a value from the checkpoint using the specified key.

        Args:
            key (str): The key to retrieve the value from the checkpoint.

        Returns:
            Any: The value associated with the key, or None if the key is absent.
        """
        return self._ckpt.get(key, None)

    def write_to_ckpt(self, key, value):
        """
        Writes a value to the checkpoint using the specified key.

        The change is committed right away. If the commit fails, the
        transaction is aborted (discarding the pending write) and a warning
        is logged; the error is not re-raised, so checkpointing stays
        best-effort.

        Args:
            key (str): The key to store the value in the checkpoint.
            value (Any): The value to be stored in the checkpoint.
        """
        self._ckpt[key] = value
        try:
            transaction.commit()
        except Exception as e:
            transaction.abort()
            # logger.warning replaces the deprecated logger.warn alias.
            logger.warning(f"failed to write checkpoint {key} to db, info: {e}")

    def close(self):
        """
        Commits any pending changes and closes the ZODB connection and database.

        A failed final commit is aborted and logged rather than raised. The
        database is closed even if closing the connection raises.
        """
        try:
            transaction.commit()
        except Exception as e:
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate; abort so the connection can close.
            transaction.abort()
            logger.warning(
                f"failed to commit pending checkpoint changes on close, info: {e}"
            )

        # getattr guards against open() having failed before these were set.
        connection = getattr(self, "_connection", None)
        db = getattr(self, "_db", None)
        try:
            if connection is not None:
                connection.close()
        finally:
            if db is not None:
                db.close()

    def exists(self, key):
        """
        Checks if a key exists in the checkpoint.

        Args:
            key (str): The key to check for existence in the checkpoint.

        Returns:
            bool: True if the key exists in the checkpoint, False otherwise.
        """
        return key in self._ckpt
5 changes: 3 additions & 2 deletions kag/interface/builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, *args, **kwargs):

self.checkpointer: CheckPointer = CheckPointer.from_config(
{
"type": "bin",
"type": "zodb",
"ckpt_dir": self.ckpt_dir,
"rank": rank,
"world_size": world_size,
Expand Down Expand Up @@ -98,7 +98,8 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:

if input_key and self.checkpointer.exists(input_key):
out = self.checkpointer.read_from_ckpt(input_key)
return out
if out is not None:
return out
output = self._invoke(input, **kwargs)
if input_key:
self.checkpointer.write_to_ckpt(input_key, output)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ ollama
tenacity
pyhocon
scikit-learn
zodb

0 comments on commit 03f5153

Please sign in to comment.