Skip to content

Commit

Permalink
Merge pull request #73 from zhuzhongshu123/increase_usability
Browse files Browse the repository at this point in the history
feat(kag)Increase usability
  • Loading branch information
zhuzhongshu123 authored Nov 26, 2024
2 parents 3460392 + 9dc8d2a commit 5ac264f
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 21 deletions.
2 changes: 1 addition & 1 deletion kag/builder/component/writer/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def invoke(
operation=alter_operation,
lead_to_builder=lead_to_builder,
)
return [None]
return [input]

def _handle(self, input: Dict, alter_operation: str, **kwargs):
"""The calling interface provided for SPGServer."""
Expand Down
86 changes: 71 additions & 15 deletions kag/builder/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,57 @@
import os
import json
import traceback
import logging
from typing import Any, Dict
from datetime import datetime
from tqdm import tqdm

from kag.common.registry import Registrable
from kag.common.utils import reset, bold, red
from kag.interface import KAGBuilderChain, SourceReaderABC
from kag.builder.model.sub_graph import SubGraph
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger()


def str_abstract(value: str):
    """Return a short, human-readable label for ``value``.

    Args:
        value: Either a filesystem path or arbitrary text.

    Returns:
        The final path component when ``value`` names an existing path,
        otherwise the first 10 characters of ``value``.
    """
    if os.path.exists(value):
        # normpath drops trailing separators, so "data/docs/" yields "docs"
        # instead of the empty string basename() would return on its own.
        name = os.path.basename(os.path.normpath(value))
        if name:
            return name
        # A bare root such as "/" has no basename; fall back to the prefix.
    return value[:10]


def dict_abstract(value: Dict):
    """Return a copy of ``value`` with every value replaced by a short label.

    Each value is converted with ``str()`` and then passed through
    ``str_abstract``; keys are kept unchanged.

    Args:
        value: Mapping whose values should be abbreviated.

    Returns:
        A new dict with the same keys and abbreviated string values.
    """
    return {k: str_abstract(str(v)) for k, v in value.items()}


def generate_hash_id(value):
    """Compute a SHA-256 identifier and a short abstract for ``value``.

    Args:
        value: A dict, or a str/bytes object. A dict is hashed through the
            string form of its sorted ``items()``, so insertion order does
            not influence the resulting id.

    Returns:
        A ``(hex_digest, abstract)`` tuple, where ``abstract`` comes from
        ``dict_abstract`` for a dict and from ``str_abstract`` otherwise.
    """
    if isinstance(value, dict):
        key = str(sorted(value.items()))
        abstract = dict_abstract(value)
    else:
        key = value
        abstract = str_abstract(value)
    # hashlib only accepts bytes-like input.
    if isinstance(key, str):
        key = key.encode("utf-8")

    return hashlib.sha256(key).hexdigest(), abstract


class CKPT:
ckpt_file_name = "kag-runner.ckpt"
ckpt_file_name = "kag-runner-{}-{}.ckpt"

def __init__(self, path: str):
def __init__(self, path: str, rank: int = 0, world_size: int = 1):
self.path = path
self.ckpt_file_path = os.path.join(self.path, CKPT.ckpt_file_name)
self.ckpt_file_path = os.path.join(
self.path, CKPT.ckpt_file_name.format(rank, world_size)
)
self._ckpt = set()
if os.path.exists(self.ckpt_file_path):
self.load()
Expand All @@ -58,11 +82,20 @@ def is_processed(self, data_id: str):
def open(self):
self.writer = open(self.ckpt_file_path, "a")

def add(self, data_id: str):
def add(self, data_id: str, data_abstract: str, info: Any):
if self.is_processed(data_id):
return
now = datetime.now()
self.writer.write(json.dumps({"id": data_id, "time": str(now)}))
self.writer.write(
json.dumps(
{
"id": data_id,
"abstract": data_abstract,
"info": info,
"timestamp": str(now),
}
)
)
self.writer.write("\n")
self.writer.flush()

Expand Down Expand Up @@ -90,14 +123,22 @@ def __init__(
if not os.path.exists(self.ckpt_dir):
os.makedirs(self.ckpt_dir, exist_ok=True)

self.ckpt = CKPT(self.ckpt_dir)
self.ckpt = CKPT(
self.ckpt_dir,
self.reader.sharding_info.get_rank(),
self.reader.sharding_info.get_world_size(),
)
msg = (
f"{bold}{red}The checkpoint file is located at {self.ckpt.ckpt_file_path}. "
f"Please access this file to obtain detailed task statistics.{reset}"
)
print(msg)

def invoke(self, input):
def process(chain, data, data_id):
def process(chain, data, data_id, data_abstract):
try:

result = chain.invoke(data, max_workers=self.chain_level_num_paralle)
return result, data_id
return data, data_id, data_abstract, result
except Exception:
traceback.print_exc()
return None
Expand All @@ -106,12 +147,13 @@ def process(chain, data, data_id):
futures = []
print(f"Processing {input}")
with ThreadPoolExecutor(self.num_parallel) as executor:
for item in self.reader.invoke(input):
item_id = generate_hash_id(item)
for item in self.reader.generate(input):
item_id, item_abstract = generate_hash_id(item)
if self.ckpt.is_processed(item_id):
continue
fut = executor.submit(process, self.chain, item, item_id)
fut = executor.submit(process, self.chain, item, item_id, item_abstract)
futures.append(fut)

for future in tqdm(
as_completed(futures),
total=len(futures),
Expand All @@ -120,8 +162,22 @@ def process(chain, data, data_id):
):
result = future.result()
if result is not None:
chain_output, item_id = result
self.ckpt.add(item_id)
item, item_id, item_abstract, chain_output = result
info = {}
num_nodes = 0
num_edges = 0
num_subgraphs = 0
for item in chain_output:
if isinstance(item, SubGraph):
num_nodes += len(item.nodes)
num_edges += len(item.edges)
num_subgraphs += 1
info = {
"num_nodes": num_nodes,
"num_edges": num_edges,
"num_subgraphs": num_subgraphs,
}
self.ckpt.add(item_id, item_abstract, info)
self.ckpt.close()


Expand Down
11 changes: 11 additions & 0 deletions kag/common/registry/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,17 @@ def resolve_class_name(
"in which case they will be automatically imported correctly."
)

@classmethod
def list_all_registered(cls, with_leaf_classes: bool = False) -> List[type]:
    """List the classes recorded in ``Registrable._registry``.

    Args:
        with_leaf_classes: When True, additionally include the first element
            of every entry stored in the dict registered under each key
            (presumably the registered implementation class - confirm
            against ``register``).

    Returns:
        The collected classes, de-duplicated and sorted by
        ``(__module__, __name__)``.
    """
    registered = set(Registrable._registry)
    if with_leaf_classes:
        for entries in Registrable._registry.values():
            if isinstance(entries, dict):
                registered.update(entry[0] for entry in entries.values())
    return sorted(registered, key=lambda x: (x.__module__, x.__name__))

@classmethod
def list_available(cls) -> List[str]:
"""List default first if it exists"""
Expand Down
12 changes: 11 additions & 1 deletion kag/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
from jinja2 import Environment, FileSystemLoader, Template
from stat import S_IWUSR as OWNER_WRITE_PERMISSION

reset = "\033[0m"
bold = "\033[1m"
underline = "\033[4m"
red = "\033[31m"
green = "\033[32m"
yellow = "\033[33m"
blue = "\033[34m"
magenta = "\033[35m"
cyan = "\033[36m"
white = "\033[37m"


def append_python_path(path: str) -> bool:
"""
Expand Down Expand Up @@ -100,7 +111,6 @@ def load_json(content):
try:
return json.loads(content)
except json.JSONDecodeError as e:

substr = content[: e.colno - 1]
return json.loads(substr)

Expand Down
9 changes: 9 additions & 0 deletions kag/interface/builder/reader_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def load_data(self, input: Input, **kwargs) -> List[Output]:

def _generate(self, data):
start, end = self.sharding_info.get_sharding_range(len(data))
worker = (
f"{self.sharding_info.get_rank()}/{self.sharding_info.get_world_size()}"
)
msg = (
f"There are total {len(data)} data to process, worker "
f"{worker} will process range [{start}, {end})"
)

print(msg)
for item in data[start:end]:
yield item

Expand Down
2 changes: 1 addition & 1 deletion kag/interface/builder/writer_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def input_types(self):

@property
def output_types(self):
    """Type of the objects this component outputs (``SubGraph``)."""
    return SubGraph

@abstractmethod
def invoke(self, input: Input, **kwargs) -> Output:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/builder/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
def test_ckpt():
ckpt = CKPT("./")
ckpt.open()
ckpt.add("aaaa")
ckpt.add("bbbb")
ckpt.add("cccc")
ckpt.add("aaaa", "aaaa", {})
ckpt.add("bbbb", "bbbb", {"num_nodes": 3})
ckpt.add("cccc", "cccc", {"num_edges": 6})
ckpt.close()

ckpt = CKPT("./")
Expand Down

0 comments on commit 5ac264f

Please sign in to comment.