Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(builder): optimized default unstructured chain invoke #86

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 49 additions & 20 deletions kag/builder/default_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
# or implied.

import logging


from concurrent.futures import ThreadPoolExecutor, as_completed
from kag.interface import (
RecordParserABC,
MappingABC,
Expand Down Expand Up @@ -105,28 +104,58 @@ def __init__(
self.writer = writer

def build(self, **kwargs):
pass

def invoke(self, input_data, max_workers=10, **kwargs):
"""
Builds the builder chain by connecting the parser, splitter, extractor, vectorizer, post-processor (if available), and writer components.
Invokes the builder chain to process the input file.

Args:
input_data: The input to be processed (e.g. the path to the input file).
max_workers (int, optional): The maximum number of threads to use. Defaults to 10.
**kwargs: Additional keyword arguments.

Returns:
KAGBuilderChain: The constructed builder chain.
List: The final output from the builder chain.
"""
if self.post_processor:
return (
self.parser
>> self.splitter
>> self.extractor
>> self.vectorizer
>> self.post_processor
>> self.writer
)
return (
self.parser
>> self.splitter
>> self.extractor
>> self.vectorizer
>> self.writer
)

def execute_node(node, node_input):
if not isinstance(node_input, list):
node_input = [node_input]
node_output = []
for item in node_input:
node_output.extend(node.invoke(item))
return node_output

def run_extract(chunk):
flow_data = [chunk]
for node in [
self.extractor,
self.vectorizer,
self.post_processor,
self.writer,
]:
if node is None:
continue
flow_data = execute_node(node, flow_data)
return flow_data

parser_output = self.parser.invoke(input_data)
splitter_output = self.splitter.invoke(parser_output)

result = []
with ThreadPoolExecutor(max_workers) as executor:
futures = [executor.submit(run_extract, chunk) for chunk in splitter_output]

from tqdm import tqdm

for inner_future in tqdm(
as_completed(futures),
total=len(futures),
desc="Chunk Extraction",
position=1,
leave=False,
):
ret = inner_future.result()
result.extend(ret)
return result
16 changes: 10 additions & 6 deletions kag/builder/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def __init__(self, path: str, rank: int = 0, world_size: int = 1):
self.path, CKPT.ckpt_file_name.format(rank, world_size)
)
self._ckpt = set()
if os.path.exists(self.ckpt_file_path):
self.load()
self.load()

def load(self):
"""
Expand All @@ -104,10 +103,15 @@ def load(self):
ckpt_file_path = os.path.join(
self.path, CKPT.ckpt_file_name.format(rank, self.world_size)
)
with open(ckpt_file_path, "r") as reader:
for line in reader:
data = json.loads(line)
self._ckpt.add(data["id"])
if os.path.exists(ckpt_file_path):
with open(ckpt_file_path, "r") as reader:
for line in reader:
data = json.loads(line)
self._ckpt.add(data["id"])
if len(self._ckpt) > 0:
print(
f"{bold}{red}Existing checkpoint found in {self.path}{reset}, with {len(self._ckpt)} processed records."
)

def is_processed(self, data_id: str):
"""
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/builder/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from kag.common.conf import KAG_CONFIG
from kag.builder.runner import CKPT, BuilderChainRunner

# pwd = os.path.dirname(__file__)
pwd = "./"
pwd = os.path.dirname(__file__)


def test_ckpt():
Expand Down
Loading