From a1fbdb62eb707816fcc95db3e25d863bfa4be09a Mon Sep 17 00:00:00 2001
From: zhuzhongshu123
Date: Tue, 19 Nov 2024 19:35:59 +0800
Subject: [PATCH] split reader and parser

---
 kag/__init__.py                               |   3 +
 kag/builder/component/__init__.py             |  30 +--
 kag/builder/component/reader/csv_reader.py    |  73 +------
 .../component/reader/dataset_reader.py        |  72 +++----
 .../component/reader/directory_reader.py      |  57 ++++++
 kag/builder/component/reader/json_reader.py   |  74 +------
 kag/builder/component/reader/yuque_reader.py  |  57 +++---
 .../component/record_parser/__init__.py       |   0
 .../component/record_parser/dict_parser.py    |  43 +++++
 .../docx_parser.py}                           |  28 +--
 .../markdown_parser.py}                       |  34 ++--
 .../pdf_parser.py}                            |  20 +-
 .../txt_parser.py}                            |  18 +-
 kag/builder/default_chain.py                  |   9 +-
 kag/builder/runner.py                         | 119 ++++++++++++
 kag/common/sharding_info.py                   | 181 ++++++++++++++++++
 kag/common/utils.py                           |  67 +++----
 kag/interface/__init__.py                     |   2 +
 kag/interface/builder/builder_chain_abc.py    |  49 ++++-
 kag/interface/builder/reader_abc.py           |  32 +++-
 kag/interface/builder/record_parser_abc.py    |  37 ++++
 tests/unit/builder/component/test_reader.py   | 151 +++++++--------
 .../builder/component/test_record_parser.py   |  91 +++++++++
 tests/unit/builder/test_runner.py             |  16 ++
 24 files changed, 839 insertions(+), 424 deletions(-)
 create mode 100644 kag/builder/component/reader/directory_reader.py
 create mode 100644 kag/builder/component/record_parser/__init__.py
 create mode 100644 kag/builder/component/record_parser/dict_parser.py
 rename kag/builder/component/{reader/docx_reader.py => record_parser/docx_parser.py} (89%)
 rename kag/builder/component/{reader/markdown_reader.py => record_parser/markdown_parser.py} (94%)
 rename kag/builder/component/{reader/pdf_reader.py => record_parser/pdf_parser.py} (94%)
 rename kag/builder/component/{reader/txt_reader.py => record_parser/txt_parser.py} (84%)
 create mode 100644 kag/builder/runner.py
 create mode 100644 kag/common/sharding_info.py
 create mode 100644 kag/interface/builder/record_parser_abc.py
 create mode 100644 tests/unit/builder/component/test_record_parser.py
 create mode 100644 tests/unit/builder/test_runner.py

diff --git a/kag/__init__.py b/kag/__init__.py
index 8de9fd50..334562d3 100644
--- a/kag/__init__.py
+++ b/kag/__init__.py
@@ -211,7 +211,10 @@ init_env()
 
 import kag.interface
+
 import kag.builder.component
+import kag.builder.default_chain
+import kag.builder.runner
 import kag.builder.prompt
 import kag.solver.prompt
 import kag.common.vectorize_model
diff --git a/kag/builder/component/__init__.py b/kag/builder/component/__init__.py
index 80d211ce..c411ca9b 100644
--- a/kag/builder/component/__init__.py
+++ b/kag/builder/component/__init__.py
@@ -23,17 +23,22 @@ from kag.builder.component.mapping.relation_mapping import RelationMapping
 from kag.builder.component.mapping.spo_mapping import SPOMapping
 from kag.builder.component.reader.csv_reader import CSVReader
-from kag.builder.component.reader.pdf_reader import PDFReader
 from kag.builder.component.reader.json_reader import JSONReader
-from kag.builder.component.reader.markdown_reader import MarkDownReader
-from kag.builder.component.reader.docx_reader import DocxReader
-from kag.builder.component.reader.txt_reader import TXTReader
+from kag.builder.component.reader.yuque_reader import YuqueReader
 from kag.builder.component.reader.dataset_reader import (
-    HotpotqaCorpusReader,
-    TwowikiCorpusReader,
     MusiqueCorpusReader,
+    HotpotqaCorpusReader,
 )
-from kag.builder.component.reader.yuque_reader import YuqueReader
+from kag.builder.component.reader.directory_reader import DirectoryReader
+
+
+from kag.builder.component.record_parser.pdf_parser import PDFParser
+from kag.builder.component.record_parser.markdown_parser import MarkDownParser
+from kag.builder.component.record_parser.docx_parser import DocxParser
+from kag.builder.component.record_parser.txt_parser import TXTParser
+from kag.builder.component.record_parser.dict_parser import DictParser
+
+
 from kag.builder.component.splitter.length_splitter import LengthSplitter
 from kag.builder.component.splitter.pattern_splitter import PatternSplitter
 from kag.builder.component.splitter.outline_splitter import OutlineSplitter
@@ -53,16 +58,17 @@
     "SPGTypeMapping",
     "RelationMapping",
     "SPOMapping",
-    "TXTReader",
-    "PDFReader",
-    "MarkDownReader",
+    "TXTParser",
+    "PDFParser",
+    "MarkDownParser",
+    "DocxParser",
+    "DictParser",
     "JSONReader",
     "HotpotqaCorpusReader",
     "MusiqueCorpusReader",
-    "TwowikiCorpusReader",
+    "DirectoryReader",
     "YuqueReader",
     "CSVReader",
-    "DocxReader",
     "LengthSplitter",
     "PatternSplitter",
     "OutlineSplitter",
diff --git a/kag/builder/component/reader/csv_reader.py b/kag/builder/component/reader/csv_reader.py
index 9dbb9ba4..6dbeb162 100644
--- a/kag/builder/component/reader/csv_reader.py
+++ b/kag/builder/component/reader/csv_reader.py
@@ -9,12 +9,9 @@
 # Unless required by applicable law or agreed to in writing, software distributed under the License
 # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied.
-import os
-from typing import List, Type, Dict
+from typing import Dict, List
 
 import pandas as pd
-
-from kag.builder.model.chunk import Chunk
 from kag.interface import SourceReaderABC
 from knext.common.base.runnable import Input, Output
 
@@ -30,70 +27,14 @@ class CSVReader(SourceReaderABC):
         **kwargs: Additional keyword arguments passed to the parent class constructor.
     """
 
-    def __init__(
-        self,
-        output_type: str = "Chunk",
-        id_col: str = "id",
-        name_col: str = "name",
-        content_col: str = "content",
-    ):
-        if output_type.lower().strip() == "dict":
-            self.output_types = Dict[str, str]
-        else:
-            self.output_types = Chunk
-        self.id_col = id_col
-        self.name_col = name_col
-        self.content_col = content_col
-
     @property
-    def input_types(self) -> Type[Input]:
+    def input_types(self) -> Input:
         return str
 
     @property
-    def output_types(self) -> Type[Output]:
-        return self._output_types
-
-    @output_types.setter
-    def output_types(self, output_types):
-        self._output_types = output_types
-
-    def invoke(self, input: Input, **kwargs) -> List[Output]:
-        """
-        Reads a CSV file and converts the data format based on the output type.
-
-        Args:
-            input (Input): Input parameter, expected to be a string representing the path to the CSV file.
-            **kwargs: Additional keyword arguments, currently unused but kept for potential future expansion.
-        Returns:
-            List[Output]:
-                - If `output_types` is `Chunk`, returns a list of Chunk objects.
-                - If `output_types` is `Dict`, returns a list of dictionaries.
- """ - - try: - data = pd.read_csv(input) - data = data.astype(str) - except Exception as e: - raise IOError(f"Failed to read the file: {e}") + def output_types(self) -> Output: + return Dict - if self.output_types == Chunk: - chunks = [] - basename, _ = os.path.splitext(os.path.basename(input)) - for idx, row in enumerate(data.to_dict(orient="records")): - kwargs = { - k: v - for k, v in row.items() - if k not in [self.id_col, self.name_col, self.content_col] - } - chunks.append( - Chunk( - id=row.get(self.id_col) - or Chunk.generate_hash_id(f"{input}#{idx}"), - name=row.get(self.name_col) or f"{basename}#{idx}", - content=row[self.content_col], - **kwargs, - ) - ) - return chunks - else: - return data.to_dict(orient="records") + def load_data(self, input: Input, **kwargs) -> List[Output]: + data = pd.read_csv(input) + return data.to_dict(orient="records") diff --git a/kag/builder/component/reader/dataset_reader.py b/kag/builder/component/reader/dataset_reader.py index 6283dfdd..cbde5df7 100644 --- a/kag/builder/component/reader/dataset_reader.py +++ b/kag/builder/component/reader/dataset_reader.py @@ -12,9 +12,9 @@ import json import os -from typing import List, Type +from typing import List, Type, Dict + -from kag.builder.model.chunk import Chunk from kag.interface import SourceReaderABC from knext.common.base.runnable import Input, Output @@ -29,27 +29,25 @@ def input_types(self) -> Type[Input]: @property def output_types(self) -> Type[Output]: """The type of output this Runnable object produces specified as a type annotation.""" - return Chunk + return Dict - def invoke(self, input: str, **kwargs) -> List[Output]: + def load_data(self, input: Input, **kwargs) -> List[Output]: if os.path.exists(str(input)): with open(input, "r") as f: corpus = json.load(f) else: corpus = json.loads(input) - chunks = [] + data = [] for item_key, item_value in corpus.items(): - chunk = Chunk( - id=item_key, - name=item_key, - content="\n".join(item_value), + data.append( + {"id": item_key, "name": item_key, "content": "\n".join(item_value)} ) - chunks.append(chunk) - return chunks + return data @SourceReaderABC.register("musique") +@SourceReaderABC.register("2wiki") class MusiqueCorpusReader(SourceReaderABC): @property def input_types(self) -> Type[Input]: @@ -59,42 +57,26 @@ def input_types(self) -> Type[Input]: @property def output_types(self) -> Type[Output]: """The type of output this Runnable object produces specified as a type annotation.""" - return Chunk + return Dict def get_basename(self, file_name: str): - base, ext = os.path.splitext(os.path.basename(file_name)) + base, _ = os.path.splitext(os.path.basename(file_name)) return base - def invoke(self, input: str, **kwargs) -> List[Output]: - id_column = kwargs.get("id_column", "title") - name_column = kwargs.get("name_column", "title") - content_column = kwargs.get("content_column", "text") - - if os.path.exists(str(input)): - with open(input, "r") as f: - corpusList = json.load(f) - else: - corpusList = input - chunks = [] - - for idx, item in enumerate(corpusList): - chunk = Chunk( - id=f"{item[id_column]}#{idx}", - name=item[name_column], - content=item[content_column], + def load_data(self, input: Input, **kwargs) -> List[Output]: + + with open(input, "r") as f: + corpus = json.load(f) + data = [] + + for idx, item in enumerate(corpus): + title = item["title"] + content = item["text"] + data.append( + { + "id": f"{title}#{idx}", + "name": title, + "content": content, + } ) - chunks.append(chunk) - return chunks - - 
-@SourceReaderABC.register("2wiki") -class TwowikiCorpusReader(MusiqueCorpusReader): - @property - def input_types(self) -> Type[Input]: - """The type of input this Runnable object accepts specified as a type annotation.""" - return str - - @property - def output_types(self) -> Type[Output]: - """The type of output this Runnable object produces specified as a type annotation.""" - return Chunk + return data diff --git a/kag/builder/component/reader/directory_reader.py b/kag/builder/component/reader/directory_reader.py new file mode 100644 index 00000000..86d8d1ff --- /dev/null +++ b/kag/builder/component/reader/directory_reader.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. +import os +import re +from typing import List + +from kag.interface import SourceReaderABC + +from knext.common.base.runnable import Input, Output + + +@SourceReaderABC.register("dir") +class DirectoryReader(SourceReaderABC): + def __init__( + self, + file_pattern: str = None, + file_suffix: str = None, + rank: int = 0, + world_size: int = 1, + ): + super().__init__(rank, world_size) + + if file_pattern is None: + if file_suffix: + file_pattern = f".*{file_suffix}$" + else: + file_pattern = r".*txt$" + self.file_pattern = re.compile(file_pattern) + + @property + def input_types(self) -> Input: + return str + + @property + def output_types(self) -> Output: + return str + + def find_files_by_regex(self, directory): + matched_files = [] + for root, dirs, files in os.walk(directory): + for file in files: + if self.file_pattern.match(file): + file_path = os.path.join(root, file) + matched_files.append(file_path) + return matched_files + + def load_data(self, input: Input, **kwargs) -> List[Output]: + return self.find_files_by_regex(input) diff --git a/kag/builder/component/reader/json_reader.py b/kag/builder/component/reader/json_reader.py index 28db2a32..06f20bb0 100644 --- a/kag/builder/component/reader/json_reader.py +++ b/kag/builder/component/reader/json_reader.py @@ -12,9 +12,8 @@ import json import os -from typing import List, Type, Dict, Union +from typing import Union, Dict, List -from kag.builder.model.chunk import Chunk from kag.interface import SourceReaderABC from knext.common.base.runnable import Input, Output @@ -22,40 +21,16 @@ @SourceReaderABC.register("json") class JSONReader(SourceReaderABC): """ - A class for reading JSON files, inheriting from `SourceReader`. - Supports converting JSON data into either a list of dictionaries or a list of Chunk objects. - - Args: - output_types (Output): Specifies the output type, which can be "Dict" or "Chunk". - **kwargs: Additional keyword arguments passed to the parent class constructor. + A class for reading JSON files, inheriting from `SourceReaderABC`. 
""" - def __init__( - self, - output_type: str = "Chunk", - id_col: str = "id", - name_col: str = "name", - content_col: str = "content", - ): - if output_type.lower().strip() == "dict": - self.output_types = Dict[str, str] - else: - self.output_types = Chunk - self.id_col = id_col - self.name_col = name_col - self.content_col = content_col - @property - def input_types(self) -> Type[Input]: + def input_types(self) -> Input: return str @property - def output_types(self) -> Type[Output]: - return self._output_types - - @output_types.setter - def output_types(self, output_types): - self._output_types = output_types + def output_types(self) -> Output: + return Dict @staticmethod def _read_from_file(file_path: str) -> Union[dict, list]: @@ -98,24 +73,7 @@ def _parse_json_string(json_string: str) -> Union[dict, list]: except json.JSONDecodeError as e: raise ValueError(f"Error parsing JSON string: {e}") - def invoke(self, input: str, **kwargs) -> List[Output]: - """ - Parses the input string data and generates a list of Chunk objects or returns the original data. - - This method supports receiving JSON-formatted strings - It can read from a file or directly parse a string. If the input data is in the expected format, it generates a list of Chunk objects; - otherwise, it throws a ValueError if the input is not a JSON array or object. - - Args: - input (str): The input data, which can be a JSON string or a file path. - **kwargs: Additional keyword arguments, currently unused but kept for potential future expansion. - Returns: - List[Output]: A list of Chunk objects or the original data. - - Raises: - ValueError: If the input data format is incorrect or parsing fails. - """ - + def load_data(self, input: Input, **kwargs) -> List[Output]: try: if os.path.exists(input): corpus = self._read_from_file(input) @@ -129,22 +87,4 @@ def invoke(self, input: str, **kwargs) -> List[Output]: if isinstance(corpus, dict): corpus = [corpus] - - if self.output_types == Chunk: - chunks = [] - basename, _ = os.path.splitext(os.path.basename(input)) - for idx, item in enumerate(corpus): - if not isinstance(item, dict): - continue - - chunk = Chunk( - id=item.get(self.id_col) - or Chunk.generate_hash_id(f"{input}#{idx}"), - name=item.get(self.name_col) or f"{basename}#{idx}", - content=item.get(self.content_col), - ) - chunks.append(chunk) - - return chunks - else: - return corpus + return corpus diff --git a/kag/builder/component/reader/yuque_reader.py b/kag/builder/component/reader/yuque_reader.py index 21bac49e..cb7f28ef 100644 --- a/kag/builder/component/reader/yuque_reader.py +++ b/kag/builder/component/reader/yuque_reader.py @@ -9,22 +9,20 @@ # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
- +import os import requests from typing import Type, List -from kag.builder.component.reader.markdown_reader import MarkDownReader -from kag.builder.model.chunk import Chunk +# from kag.builder.component.reader.markdown_reader import MarkDownReader from kag.interface import SourceReaderABC -from kag.interface import LLMClient from knext.common.base.runnable import Input, Output @SourceReaderABC.register("yuque") class YuqueReader(SourceReaderABC): - def __init__(self, token: str, llm: LLMClient = None, cut_depth: int = 1): + def __init__(self, token: str, rank: int = 0, world_size: int = 1): + super().__init__(rank, world_size) self.token = token - self.markdown_reader = MarkDownReader(llm, cut_depth) @property def input_types(self) -> Type[Input]: @@ -34,33 +32,22 @@ def input_types(self) -> Type[Input]: @property def output_types(self) -> Type[Output]: """The type of output this Runnable object produces specified as a type annotation.""" - return Chunk - - @staticmethod - def get_yuque_api_data(token, url): - headers = {"X-Auth-Token": token} - - try: - response = requests.get(url, headers=headers) - response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx) - return response.json()["data"] # Assuming the API returns JSON data - except requests.exceptions.HTTPError as http_err: - print(f"HTTP error occurred: {http_err}") - except requests.exceptions.RequestException as err: - print(f"Error occurred: {err}") - except Exception as err: - print(f"An error occurred: {err}") - - def invoke(self, input: str, **kwargs) -> List[Output]: - if not input: - raise ValueError("Input cannot be empty") - - url: str = input - data = self.get_yuque_api_data(self.token, url) - id = data.get("id", "") - title = data.get("title", "") - content = data.get("body", "") - - chunks = self.markdown_reader.solve_content(id, title, content) + return str - return chunks + def get_yuque_api_data(self, url): + headers = {"X-Auth-Token": self.token} + response = requests.get(url, headers=headers) + response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx) + return response.json()["data"] # Assuming the API returns JSON data + + def load_data(self, input: Input, **kwargs) -> List[Output]: + url = input + data = self.get_yuque_api_data(url) + if isinstance(data, dict): + # for single yuque doc + return [f"{self.token}@{url}"] + output = [] + for item in data: + slug = item["slug"] + output.append(os.path.join(url, slug)) + return [f"{self.token}@{url}" for url in output] diff --git a/kag/builder/component/record_parser/__init__.py b/kag/builder/component/record_parser/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kag/builder/component/record_parser/dict_parser.py b/kag/builder/component/record_parser/dict_parser.py new file mode 100644 index 00000000..face2a81 --- /dev/null +++ b/kag/builder/component/record_parser/dict_parser.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
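+#
+# Usage sketch (illustrative values; the config keys mirror the constructor
+# arguments of DictParser below):
+#
+#   parser = RecordParserABC.from_config({"type": "dict"})
+#   chunks = parser.invoke({"id": "1", "name": "doc-1", "content": "hello."})
+#   # -> [Chunk(id="1", name="doc-1", content="hello.")]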
+
+from typing import Dict, List
+from kag.interface import RecordParserABC
+from knext.common.base.runnable import Output, Input
+from kag.builder.model.chunk import Chunk
+
+
+@RecordParserABC.register("dict")
+class DictParser(RecordParserABC):
+    """
+    A class for converting dict objects into Chunks, inheriting from `RecordParserABC`.
+
+    Args:
+        id_col (str): The key of the chunk id in the input dict. Default is "id".
+        name_col (str): The key of the chunk name in the input dict. Default is "name".
+        content_col (str): The key of the chunk content in the input dict. Default is "content".
+    """
+
+    def __init__(
+        self, id_col: str = "id", name_col: str = "name", content_col: str = "content"
+    ):
+        self.id_col = id_col
+        self.name_col = name_col
+        self.content_col = content_col
+
+    @property
+    def input_types(self) -> Input:
+        return Dict
+
+    def invoke(self, input: Input, **kwargs) -> List[Output]:
+        chunk_id = input.pop(self.id_col)
+        chunk_name = input.pop(self.name_col)
+        chunk_content = input.pop(self.content_col)
+        return [Chunk(id=chunk_id, name=chunk_name, content=chunk_content, **input)]
diff --git a/kag/builder/component/reader/docx_reader.py b/kag/builder/component/record_parser/docx_parser.py
similarity index 89%
rename from kag/builder/component/reader/docx_reader.py
rename to kag/builder/component/record_parser/docx_parser.py
index 460741fd..519e193a 100644
--- a/kag/builder/component/reader/docx_reader.py
+++ b/kag/builder/component/record_parser/docx_parser.py
@@ -11,12 +11,12 @@
 # or implied.
 
 import os
-from typing import List, Type, Union
+from typing import List, Union
 
 from docx import Document
 from kag.interface import LLMClient
 
 from kag.builder.model.chunk import Chunk
-from kag.interface import SourceReaderABC
+from kag.interface import RecordParserABC
 from kag.builder.prompt.outline_prompt import OutlinePrompt
 from kag.common.conf import KAG_PROJECT_CONF
 
@@ -41,10 +41,10 @@ def split_txt(content):
     return res
 
 
-@SourceReaderABC.register("docx")
-class DocxReader(SourceReaderABC):
+@RecordParserABC.register("docx")
+class DocxParser(RecordParserABC):
     """
-    A class for reading Docx files, inheriting from SourceReader.
+    A class for parsing Docx files, inheriting from RecordParserABC.
     This class is specifically designed to extract text content from Docx files and generate Chunk objects based on the extracted content.
""" @@ -52,14 +52,6 @@ def __init__(self, llm: LLMClient = None): self.llm = llm self.prompt = OutlinePrompt(KAG_PROJECT_CONF.language) - @property - def input_types(self) -> Type[Input]: - return str - - @property - def output_types(self) -> Type[Output]: - return Chunk - def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]: if isinstance(chunk, Chunk): chunk = [chunk] @@ -173,13 +165,3 @@ def invoke(self, input: Input, **kwargs) -> List[Output]: ] return chunks - - -if __name__ == "__main__": - reader = DocxReader() - print(reader.output_types) - file_path = os.path.dirname(__file__) - res = reader.invoke( - os.path.join(file_path, "../../../../tests/builder/data/test_docx.docx") - ) - print(res) diff --git a/kag/builder/component/reader/markdown_reader.py b/kag/builder/component/record_parser/markdown_parser.py similarity index 94% rename from kag/builder/component/reader/markdown_reader.py rename to kag/builder/component/record_parser/markdown_parser.py index 209148ee..272f3b1c 100644 --- a/kag/builder/component/reader/markdown_reader.py +++ b/kag/builder/component/record_parser/markdown_parser.py @@ -15,7 +15,7 @@ import bs4.element import markdown from bs4 import BeautifulSoup, Tag -from typing import List, Type +from typing import List import logging import re import requests @@ -23,7 +23,7 @@ from io import StringIO from tenacity import stop_after_attempt, retry -from kag.interface import SourceReaderABC +from kag.interface import RecordParserABC from kag.builder.model.chunk import Chunk, ChunkTypeEnum from kag.interface import LLMClient from kag.common.conf import KAG_PROJECT_CONF @@ -34,10 +34,10 @@ logger = logging.getLogger(__name__) -@SourceReaderABC.register("md") -class MarkDownReader(SourceReaderABC): +@RecordParserABC.register("md") +class MarkDownParser(RecordParserABC): """ - A class for reading MarkDown files, inheriting from `SourceReader`. + A class for reading MarkDown files, inheriting from `RecordParserABC`. Supports converting MarkDown data into a list of Chunk objects. Args: @@ -54,14 +54,6 @@ def __init__(self, llm: LLMClient = None, cut_depth: int = 1): language=KAG_PROJECT_CONF.language ) - @property - def input_types(self) -> Type[Input]: - return str - - @property - def output_types(self) -> Type[Output]: - return Chunk - def to_text(self, level_tags): """ Converts parsed hierarchical tags into text content. 
@@ -424,3 +416,19 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
 
         chunks = self.solve_content(input, basename, content)
         return chunks
+
+
+@RecordParserABC.register("yuque")
+class YuqueParser(MarkDownParser):
+    def invoke(self, input: Input, **kwargs) -> List[Output]:
+        token, url = input.split("@", 1)
+        headers = {"X-Auth-Token": token}
+        response = requests.get(url, headers=headers)
+        response.raise_for_status()  # Raise an HTTPError for bad responses (4xx and 5xx)
+        data = response.json()["data"]
+        id = data.get("id", "")
+        title = data.get("title", "")
+        content = data.get("body", "")
+
+        chunks = self.solve_content(id, title, content)
+        return chunks
diff --git a/kag/builder/component/reader/pdf_reader.py b/kag/builder/component/record_parser/pdf_parser.py
similarity index 94%
rename from kag/builder/component/reader/pdf_reader.py
rename to kag/builder/component/record_parser/pdf_parser.py
index 4e935f26..f8851391 100644
--- a/kag/builder/component/reader/pdf_reader.py
+++ b/kag/builder/component/record_parser/pdf_parser.py
@@ -12,13 +12,13 @@
 
 import os
 import re
-from typing import List, Sequence, Type, Union
+from typing import List, Sequence, Union
 
 import pdfminer.layout  # noqa
 
 from kag.builder.model.chunk import Chunk
-from kag.interface import SourceReaderABC
+from kag.interface import RecordParserABC
 from kag.builder.prompt.outline_prompt import OutlinePrompt
 from kag.interface import LLMClient
@@ -37,10 +37,10 @@
 logger = logging.getLogger(__name__)
 
 
-@SourceReaderABC.register("pdf")
-class PDFReader(SourceReaderABC):
+@RecordParserABC.register("pdf")
+class PDFParser(RecordParserABC):
     """
-    A PDF reader class that inherits from SourceReader.
+    A PDF parser class that inherits from RecordParserABC.
     """
 
     def __init__(self, llm: LLMClient = None, split_level: int = 3):
@@ -51,14 +51,6 @@ def __init__(self, llm: LLMClient = None, split_level: int = 3):
         self.llm = llm
         self.prompt = OutlinePrompt(KAG_PROJECT_CONF.language)
 
-    @property
-    def input_types(self) -> Type[Input]:
-        return str
-
-    @property
-    def output_types(self) -> Type[Output]:
-        return Chunk
-
     def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]:
         if isinstance(chunk, Chunk):
             chunk = [chunk]
@@ -142,7 +134,7 @@ def _extract_text_from_page(page_layout: LTPage) -> str:
             text += element.get_text()
         return text
 
-    def invoke(self, input: str, **kwargs) -> Sequence[Output]:
+    def invoke(self, input: Input, **kwargs) -> Sequence[Output]:
         """
         Processes a PDF file, splitting or extracting content based on configuration.
 
diff --git a/kag/builder/component/reader/txt_reader.py b/kag/builder/component/record_parser/txt_parser.py
similarity index 84%
rename from kag/builder/component/reader/txt_reader.py
rename to kag/builder/component/record_parser/txt_parser.py
index f84b0657..ca0e3782 100644
--- a/kag/builder/component/reader/txt_reader.py
+++ b/kag/builder/component/record_parser/txt_parser.py
@@ -11,27 +11,19 @@
 # or implied.
 
 import os
-from typing import List, Type
+from typing import List
 
 from kag.builder.model.chunk import Chunk
-from kag.interface import SourceReaderABC
+from kag.interface import RecordParserABC
 from knext.common.base.runnable import Input, Output
 
 
-@SourceReaderABC.register("txt")
-class TXTReader(SourceReaderABC):
+@RecordParserABC.register("txt")
+class TXTParser(RecordParserABC):
     """
-    A PDF reader class that inherits from SourceReader.
+    A TXT parser class that inherits from RecordParserABC.
""" - @property - def input_types(self) -> Type[Input]: - return str - - @property - def output_types(self) -> Type[Output]: - return Chunk - def invoke(self, input: Input, **kwargs) -> List[Output]: """ The main method for processing text reading. This method reads the content of the input (which can be a file path or text content) and converts it into a Chunk object. diff --git a/kag/builder/default_chain.py b/kag/builder/default_chain.py index ab06cda8..37cdacd2 100644 --- a/kag/builder/default_chain.py +++ b/kag/builder/default_chain.py @@ -20,10 +20,11 @@ from kag.interface import ( + SourceReaderABC, ExtractorABC, SplitterABC, VectorizerABC, - SourceReaderABC, + RecordParserABC, PostProcessorABC, SinkWriterABC, KAGBuilderChain, @@ -97,14 +98,14 @@ def invoke(self, file_path, max_workers=10, **kwargs): class DefaultUnstructuredBuilderChain(KAGBuilderChain): def __init__( self, - reader: SourceReaderABC, + parser: RecordParserABC, splitter: SplitterABC, extractor: ExtractorABC, vectorizer: VectorizerABC, post_processor: PostProcessorABC, writer: SinkWriterABC, ): - self.reader = reader + self.parser = parser self.splitter = splitter self.extractor = extractor self.vectorizer = vectorizer @@ -113,7 +114,7 @@ def __init__( def build(self, **kwargs): return ( - self.reader + self.parser >> self.splitter >> self.extractor >> self.vectorizer diff --git a/kag/builder/runner.py b/kag/builder/runner.py new file mode 100644 index 00000000..9ac15451 --- /dev/null +++ b/kag/builder/runner.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
+
+
+import hashlib
+import os
+import json
+import traceback
+from datetime import datetime
+from tqdm import tqdm
+from kag.common.registry import Registrable
+from kag.interface import KAGBuilderChain, SourceReaderABC
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+def generate_hash_id(value):
+    if isinstance(value, dict):
+        sorted_items = sorted(value.items())
+        value = str(sorted_items)
+    if isinstance(value, str):
+        value = value.encode("utf-8")
+    hasher = hashlib.sha256()
+    hasher.update(value)
+    return hasher.hexdigest()
+
+
+class CKPT:
+    ckpt_file_name = "kag-runner.ckpt"
+
+    def __init__(self, path: str):
+        self.path = path
+        self.ckpt_file_path = os.path.join(self.path, CKPT.ckpt_file_name)
+        self._ckpt = set()
+        if os.path.exists(self.ckpt_file_path):
+            self.load()
+
+    def load(self):
+        with open(self.ckpt_file_path, "r") as reader:
+            for line in reader:
+                data = json.loads(line)
+                self._ckpt.add(data["id"])
+
+    def is_processed(self, data_id: str):
+        return data_id in self._ckpt
+
+    def open(self):
+        self.writer = open(self.ckpt_file_path, "a")
+
+    def add(self, data_id: str):
+        if self.is_processed(data_id):
+            return
+        now = datetime.now()
+        self.writer.write(json.dumps({"id": data_id, "time": str(now)}))
+        self.writer.write("\n")
+        self.writer.flush()
+
+    def close(self):
+        self.writer.flush()
+        self.writer.close()
+
+
+class BuilderChainRunner(Registrable):
+    def __init__(
+        self,
+        reader: SourceReaderABC,
+        chain: KAGBuilderChain,
+        num_parallel: int = 4,
+        ckpt_dir: str = None,
+    ):
+        self.reader = reader
+        self.chain = chain
+        self.num_parallel = num_parallel
+        if ckpt_dir is None:
+            ckpt_dir = "./ckpt"
+        self.ckpt_dir = ckpt_dir
+        if not os.path.exists(self.ckpt_dir):
+            os.makedirs(self.ckpt_dir, exist_ok=True)
+
+        self.ckpt = CKPT(self.ckpt_dir)
+
+    def invoke(self, input):
+        def process(chain, data, data_id):
+            try:
+                result = chain.invoke(data)
+                return result, data_id
+            except Exception:
+                traceback.print_exc()
+                return None
+
+        self.ckpt.open()
+        futures = []
+        with ThreadPoolExecutor(self.num_parallel) as executor:
+            for item in self.reader.invoke(input):
+                item_id = generate_hash_id(item)
+                if self.ckpt.is_processed(item_id):
+                    continue
+                fut = executor.submit(process, self.chain, item, item_id)
+                futures.append(fut)
+            for future in tqdm(
+                as_completed(futures), total=len(futures), desc="Processing"
+            ):
+                result = future.result()
+                if result is not None:
+                    chain_output, item_id = result
+                    self.ckpt.add(item_id)
+        self.ckpt.close()
+
+
+BuilderChainRunner.register("base", as_default=True)(BuilderChainRunner)
diff --git a/kag/common/sharding_info.py b/kag/common/sharding_info.py
new file mode 100644
index 00000000..e68198fd
--- /dev/null
+++ b/kag/common/sharding_info.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+from kag.common.registry import Registrable
+
+
+class ShardingInfo(Registrable):
+    """
+    ShardingInfo is used to record sharding-related information. Each machine can contain multiple instances,
+    and each instance can contain multiple processes. The rank and world_size can then be calculated accordingly.
+    When shard_id and shard_count are explicitly given, they are directly used as rank and world_size,
+    mainly used to obtain tasks from a cache server.
+ """ + + def __init__( + self, + machine_id: int = 0, + machine_count: int = 1, + instance_id: int = 0, + instance_count: int = 1, + process_id: int = 0, + process_count: int = 1, + shard_id: int = None, + shard_count: int = None, + ): + self.instance_id = instance_id + self.instance_count = instance_count + self.machine_id = machine_id + self.machine_count = machine_count + self.process_id = process_id + self.process_count = process_count + self.shard_id = shard_id + self.shard_count = shard_count + + self.shard_by_machine = True + self.shard_by_instance = True + self.shard_by_process = True + + def shard_by( + self, machine: bool = True, instance: bool = True, process: bool = True + ): + self.shard_by_machine = machine + self.shard_by_instance = instance + self.shard_by_process = process + + def get_rank(self): + if self.shard_id is not None: + return self.shard_id + if self.shard_by_machine: + machine_id = self.machine_id + else: + machine_id = 0 + if self.shard_by_instance: + instance_id, instance_count = self.instance_id, self.instance_count + else: + instance_id, instance_count = 0, 1 + if self.shard_by_process: + process_id, process_count = self.process_id, self.process_count + else: + process_id, process_count = 0, 1 + + return process_count * (machine_id * instance_count + instance_id) + process_id + + def get_world_size(self): + if self.shard_count is not None: + return self.shard_count + world_size = 1 + if self.shard_by_machine: + world_size *= self.machine_count + if self.shard_by_instance: + world_size *= self.instance_count + if self.shard_by_process: + world_size *= self.process_count + return world_size + + def get_sharding_range(self, total: int): + rank = self.get_rank() + world_size = self.get_world_size() + if total % world_size == 0: + workload = total // world_size + else: + workload = total // world_size + 1 + start = workload * rank + end = min(total, workload * (rank + 1)) + return start, end + + @property + def is_master_process(self): + return self.process_id == 0 + + @property + def is_master_instance(self): + return self.instance_id == 0 + + @property + def is_master_machine(self): + return self.machine_id == 0 + + def __str__(self): + content = ( + f"ShardingInfo: rank={self.get_rank()}, world_size={self.get_world_size()}, " + f"machine: {self.machine_id}/{self.machine_count}, " + f"instance: {self.instance_id}/{self.instance_count}, " + f"process: {self.process_id}/{self.process_count}" + ) + return content + + __repr__ = __str__ + + def copy(self): + return ShardingInfo( + self.machine_id, + self.machine_count, + self.instance_id, + self.instance_count, + self.process_id, + self.process_count, + self.shard_id, + self.shard_count, + ) + + +def partition_based_sharding(num_partitions: int, sharding_info: ShardingInfo): + + """ + The layerwise inference mode requires a special sharding strategy for seed generation and inference + return export, i.e. partition based sharding. + In general cases, each partition divides its seeds according to the total number of + workers(=machines*instances*proceeese) directly. For example, when machine_count=2 and num_partitions=4, + the division in each partition is as follows: + + [[machine 0 | machine 1], [machine 0 | machine 1], [machine 0 | machine 1], [machine 0 | machine 1]] + + However, in the layerwise inference mode, we first assign the partitions to different machine groups, + and then divide the seeds in the respective machine groups to retain the locality. 
+
+    if machine_count > num_partitions:
+        each machine group contains multiple machines and processes one partition together.
+    else:
+        each machine group contains one machine that needs to process one or more partitions.
+    Therefore, we need to recompute the sharding_info according to the machine group; here are some examples:
+
+    machine_count=2, num_partitions=1
+      ==> [[machine 0, machine 1]]
+      ==> machine_id = 0/1, machine_count = 2
+
+    machine_count=2, num_partitions=2
+      ==> [[machine 0], [machine 1]]
+      ==> machine_id = 0, machine_count = 1
+
+    machine_count=2, num_partitions=4
+      ==> [[machine 0], [machine 0], [machine 1], [machine 1]]
+      ==> machine_id = 0, machine_count = 1
+    """
+
+    sharding_info = sharding_info.copy()
+    machine_id = sharding_info.machine_id
+    machine_count = sharding_info.machine_count
+
+    if machine_count <= num_partitions:
+        if num_partitions % machine_count != 0:
+            msg = f"num_partitions {num_partitions} must be divisible by num_machines {machine_count}"
+            raise ValueError(msg)
+        num_partitions_per_machine = num_partitions // machine_count
+        responsible_partitions = [
+            machine_id * num_partitions_per_machine + x
+            for x in range(num_partitions_per_machine)
+        ]
+        sharding_info.machine_id = 0
+        sharding_info.machine_count = 1
+        return sharding_info, responsible_partitions
+    else:
+        if machine_count % num_partitions != 0:
+            msg = f"num_machines {machine_count} must be divisible by num_partitions {num_partitions}"
+            raise ValueError(msg)
+        num_machine_per_partition = machine_count // num_partitions
+        responsible_partitions = [machine_id // num_machine_per_partition]
+        sharding_info.machine_id = sharding_info.machine_id % num_machine_per_partition
+        sharding_info.machine_count = num_machine_per_partition
+        return sharding_info, responsible_partitions
+
+
+ShardingInfo.register("base")(ShardingInfo)
diff --git a/kag/common/utils.py b/kag/common/utils.py
index 87ff1d39..29954a83 100644
--- a/kag/common/utils.py
+++ b/kag/common/utils.py
@@ -12,8 +12,7 @@
 import re
 import sys
 import json
-from typing import Type, Tuple
-import inspect
+from typing import Tuple
 import os
 from pathlib import Path
 import importlib
@@ -23,42 +22,6 @@
 from stat import S_IWUSR as OWNER_WRITE_PERMISSION
 
 
-def _register(root, path, files, class_type):
-    relative_path = os.path.relpath(path, root)
-    module_prefix = relative_path.replace(".", "").replace("/", ".")
-    module_prefix = module_prefix + "." if module_prefix else ""
-    for file_name in files:
-        if file_name.endswith(".py"):
-            module_name = module_prefix + os.path.splitext(file_name)[0]
-            import importlib
-
-            module = importlib.import_module(module_name)
-            classes = inspect.getmembers(module, inspect.isclass)
-            for class_name, class_obj in classes:
-                if (
-                    issubclass(class_obj, class_type)
-                    and inspect.getmodule(class_obj) == module
-                ):
-
-                    class_type.register(
-                        name=class_name,
-                        local_path=os.path.join(path, file_name),
-                        module_path=module_name,
-                    )(class_obj)
-
-
-def register_from_package(path: str, class_type: Type) -> None:
-    """
-    Register all classes under the given package.
-    Only registered classes can be recognized by kag.
-    """
-    if not append_python_path(path):
-        return
-    for root, dirs, files in os.walk(path):
-        _register(path, root, files, class_type)
-    class_type._has_registered = True
-
-
 def append_python_path(path: str) -> bool:
     """
     Append the given path to `sys.path`.
@@ -208,3 +171,31 @@ def get_vector_field_name(property_key: str):
     name = f"{property_key}_vector"
     name = to_snake_case(name)
     return "_" + name
+
+
+def split_list_into_n_parts(lst, n):
+    length = len(lst)
+    part_size = length // n
+    remainder = length % n
+
+    result = []
+
+    # Split the list: the first `remainder` parts receive one extra element
+    start = 0
+    for i in range(n):
+        # Compute the number of elements in the current part
+        if i < remainder:
+            end = start + part_size + 1
+        else:
+            end = start + part_size
+
+        # Append the current part to the result
+        result.append(lst[start:end])
+
+        # Advance the start position
+        start = end
+
+    return result
diff --git a/kag/interface/__init__.py b/kag/interface/__init__.py
index 79f77635..648e7b35 100644
--- a/kag/interface/__init__.py
+++ b/kag/interface/__init__.py
@@ -13,6 +13,7 @@
 from kag.interface.common.llm_client import LLMClient
 from kag.interface.common.vectorize_model import VectorizeModelABC, EmbeddingVector
 
+from kag.interface.builder.record_parser_abc import RecordParserABC
 from kag.interface.builder.reader_abc import SourceReaderABC
 from kag.interface.builder.splitter_abc import SplitterABC
 from kag.interface.builder.extractor_abc import ExtractorABC
@@ -40,6 +41,7 @@
     "LLMClient",
     "VectorizeModelABC",
     "EmbeddingVector",
+    "RecordParserABC",
     "SourceReaderABC",
     "SplitterABC",
     "ExtractorABC",
diff --git a/kag/interface/builder/builder_chain_abc.py b/kag/interface/builder/builder_chain_abc.py
index 69584c75..b5879353 100644
--- a/kag/interface/builder/builder_chain_abc.py
+++ b/kag/interface/builder/builder_chain_abc.py
@@ -1,7 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
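+#
+# KAGBuilderChain.invoke (below) executes the chain's DAG in topological
+# order: source nodes receive the input file path, every other node consumes
+# the concatenated outputs of its predecessors, and the outputs of all sink
+# nodes (out-degree 0) form the final result.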
+
+from typing import List
 from kag.common.registry import Registrable
 from knext.builder.builder_chain_abc import BuilderChainABC
 
 
 class KAGBuilderChain(BuilderChainABC, Registrable):
-    pass
+    def invoke(self, file_path, **kwargs):
+        def execute_node(node, inputs: List[str]):
+            result = []
+            for item in inputs:
+                res = node.invoke(item)
+                result.extend(res)
+            return result
+
+        chain = self.build(file_path=file_path, **kwargs)
+        dag = chain.dag
+        import networkx as nx
+
+        nodes = list(nx.topological_sort(dag))
+        node_outputs = {}
+        processed_node_names = []
+        for node in nodes:
+            node_name = type(node).__name__.split(".")[-1]
+            processed_node_names.append(node_name)
+            predecessors = list(dag.predecessors(node))
+            if len(predecessors) == 0:
+                node_input = [file_path]
+                node_output = execute_node(node, node_input)
+            else:
+                node_input = []
+                for p in predecessors:
+                    node_input.extend(node_outputs[p])
+                node_output = execute_node(node, node_input)
+            node_outputs[node] = node_output
+        output_nodes = [node for node in nodes if dag.out_degree(node) == 0]
+        final_output = []
+        for node in output_nodes:
+            if node in node_outputs:
+                final_output.extend(node_outputs[node])
+
+        return final_output
diff --git a/kag/interface/builder/reader_abc.py b/kag/interface/builder/reader_abc.py
index bff25e81..76b0f37a 100644
--- a/kag/interface/builder/reader_abc.py
+++ b/kag/interface/builder/reader_abc.py
@@ -10,10 +10,9 @@
 # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied.
 from abc import ABC, abstractmethod
-from typing import List, Dict, Union
-
+from typing import Any, Generator, List
 from kag.interface.builder.base import BuilderComponent
-from kag.builder.model.chunk import Chunk
+from kag.common.sharding_info import ShardingInfo
 from knext.common.base.runnable import Input, Output
 
 
@@ -22,16 +21,35 @@ class SourceReaderABC(BuilderComponent, ABC):
     Interface for reading files into a list of unstructured chunks or structured dicts.
     """
 
+    def __init__(self, rank: int = None, world_size: int = None):
+        if rank is None or world_size is None:
+            from kag.common.env import get_rank, get_world_size
+
+            rank = get_rank(0)
+            world_size = get_world_size(1)
+        self.sharding_info = ShardingInfo(shard_id=rank, shard_count=world_size)
+
     @property
     def input_types(self) -> Input:
         return str
 
     @property
     def output_types(self) -> Output:
-        return Union[Chunk, Dict]
+        return Any
 
     @abstractmethod
+    def load_data(self, input: Input, **kwargs) -> List[Output]:
+        raise NotImplementedError("load not implemented yet.")
+
+    def _generate(self, data):
+        start, end = self.sharding_info.get_sharding_range(len(data))
+        for item in data[start:end]:
+            yield item
+
+    def generate(self, input: Input, **kwargs) -> Generator[Output, Input, None]:
+        data = self.load_data(input, **kwargs)
+        for item in self._generate(data):
+            yield item
+
     def invoke(self, input: Input, **kwargs) -> List[Output]:
-        raise NotImplementedError(
-            f"`invoke` is not currently supported for {self.__class__.__name__}."
-        )
+        return list(self.generate(input, **kwargs))
diff --git a/kag/interface/builder/record_parser_abc.py b/kag/interface/builder/record_parser_abc.py
new file mode 100644
index 00000000..74766284
--- /dev/null
+++ b/kag/interface/builder/record_parser_abc.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+from abc import ABC, abstractmethod
+from typing import List
+
+from kag.interface.builder.base import BuilderComponent
+from kag.builder.model.chunk import Chunk
+from knext.common.base.runnable import Input, Output
+
+
+class RecordParserABC(BuilderComponent, ABC):
+    """
+    Interface for parsing source records (file paths or structured dicts) into a list of Chunk objects.
+    """
+
+    @property
+    def input_types(self) -> Input:
+        return str
+
+    @property
+    def output_types(self) -> Output:
+        return Chunk
+
+    @abstractmethod
+    def invoke(self, input: Input, **kwargs) -> List[Output]:
+        raise NotImplementedError(
+            f"`invoke` is not currently supported for {self.__class__.__name__}."
+        )
diff --git a/tests/unit/builder/component/test_reader.py b/tests/unit/builder/component/test_reader.py
index 21fcb652..f1fbe51d 100644
--- a/tests/unit/builder/component/test_reader.py
+++ b/tests/unit/builder/component/test_reader.py
@@ -6,115 +6,94 @@
 
 from kag.interface import SourceReaderABC
-from unittest.mock import patch, mock_open, MagicMock
 from kag.builder.model.chunk import Chunk, ChunkTypeEnum
 
 pwd = os.path.dirname(__file__)
 
 
-def test_text_reader():
-    reader = SourceReaderABC.from_config({"type": "txt"})
-    text = "您好!"
-    chunks = reader.invoke(text)
-    assert len(chunks) == 1 and chunks[0].content == text
-
-    file_path = os.path.join(pwd, "../data/test_txt.txt")
-    chunks = reader.invoke(file_path)
-    with open(file_path) as f:
-        content = f.read()
-    chunks = reader.invoke(file_path)
-    assert len(chunks) == 1
-    assert chunks[0].content == content
-    assert chunks[0].id == Chunk.generate_hash_id(file_path)
-
-
-def test_docx_reader():
-    reader = SourceReaderABC.from_config({"type": "docx"})
-
-    file_path = os.path.join(pwd, "../data/test_docx.docx")
-    chunks = reader.invoke(file_path)
-    # Assert the expected result
-    assert len(chunks) == 30
-    assert len(chunks[0].content) > 0
-
-
 def test_json_reader():
-    reader = SourceReaderABC.from_config(
-        {"type": "json", "name_col": "title", "content_col": "text"}
-    )
+    reader = SourceReaderABC.from_config({"type": "json", "rank": 0, "world_size": 1})
     file_path = os.path.join(pwd, "../data/test_json.json")
     with open(file_path, "r") as r:
         json_string = r.read()
     json_content = json.loads(json_string)
-    # read from json file
-    chunks = reader.invoke(file_path)
-    assert len(chunks) == len(json_content)
-    for chunk, json_item in zip(chunks, json_content):
-        assert chunk.content == json_item["text"]
-        assert chunk.name == json_item["title"]
-    # read from json string directly
-    chunks = reader.invoke(json_string)
-    assert len(chunks) == len(json_content)
-    for chunk, json_item in zip(chunks, json_content):
-        assert chunk.content == json_item["text"]
-        assert chunk.name == json_item["title"]
+    data = reader.invoke(file_path)
+    assert len(data) == len(json_content)
+    for l, r in zip(data, json_content):
+        assert l == r
+
+    reader_1 = SourceReaderABC.from_config({"type": "json", "rank": 0, "world_size": 2})
+    reader_2 = SourceReaderABC.from_config({"type": "json", "rank": 1, "world_size": 2})
+
+    data_1 = reader_1.invoke(file_path)
+    data_2 = reader_2.invoke(file_path)
+    data = data_1 + data_2
+    assert len(data) == len(json_content)
+    for l, r in zip(data, json_content):
+        assert l == r
 
 
 def test_csv_reader():
-    reader = SourceReaderABC.from_config(
-        {"type": "csv", "id_col": "idx", "name_col": "title", "content_col": "text"}
-    )
+    reader = SourceReaderABC.from_config({"type": "csv", "rank": 0, "world_size": 1})
     file_path = os.path.join(pwd, "../data/test_csv.csv")
-    chunks = reader.invoke(file_path)
-
-    data = pd.read_csv(file_path)
-    assert len(chunks) == len(data)
-    for idx in range(len(chunks)):
-        chunk = chunks[idx]
-        row = data.iloc[idx]
-        assert str(chunk.id) == str(row.idx)
-        assert chunk.name == row.title
-        assert chunk.content == row.text
-
-
-def test_md_reader():
-    reader = SourceReaderABC.from_config({"type": "md", "cut_depth": 1})
-    file_path = os.path.join(pwd, "../data/test_markdown.md")
-    chunks = reader.invoke(file_path)
-    assert len(chunks) > 0
-    assert chunks[0].name == "test_markdown#0"
-
-
-def test_pdf_reader():
-    reader = SourceReaderABC.from_config({"type": "pdf"})
-
-    page = "Header\nContent 1\nContent 2\nFooter"
-    watermark = "Header"
-    expected = ["Content 1", "Content 2"]
-    result = reader._process_single_page(
-        page, watermark, remove_header=True, remove_footnote=True
+    csv_content = []
+    for _, item in pd.read_csv(file_path).iterrows():
+        csv_content.append(item.to_dict())
+    data = reader.invoke(file_path)
+
+    assert len(data) == len(csv_content)
+    for l, r in zip(data, csv_content):
+        assert l == r
+
+    reader_1 = SourceReaderABC.from_config({"type": "csv", "rank": 0, "world_size": 2})
+    reader_2 = SourceReaderABC.from_config({"type": "csv", "rank": 1, "world_size": 2})
+
+    data_1 = reader_1.invoke(file_path)
+    data_2 = reader_2.invoke(file_path)
+    data = data_1 + data_2
+    assert len(data) == len(csv_content)
+    for l, r in zip(data, csv_content):
+        assert l == r
+
+
+def test_directory_reader():
+    reader = SourceReaderABC.from_config({"type": "dir", "file_suffix": "json"})
+    dir_path = os.path.join(pwd, "../data/")
+    all_data = reader.invoke(dir_path)
+    for item in all_data:
+        assert os.path.exists(item)
+        assert item.endswith("json")
+
+    reader_1 = SourceReaderABC.from_config(
+        {"type": "dir", "file_suffix": "json", "rank": 0, "world_size": 2}
+    )
+    reader_2 = SourceReaderABC.from_config(
+        {"type": "dir", "file_suffix": "json", "rank": 1, "world_size": 2}
     )
-    assert result == expected
-    file_path = os.path.join(pwd, "../data/test_pdf.pdf")
-    chunks = reader.invoke(file_path)
-    assert chunks[0].name == "test_pdf#0"
+
+    data_1 = reader_1.invoke(dir_path)
+    data_2 = reader_2.invoke(dir_path)
+    assert len(all_data) == len(data_1) + len(data_2)
+
+    reader = SourceReaderABC.from_config({"type": "dir", "file_pattern": ".*txt$"})
+    all_data = reader.invoke(dir_path)
+
+    for item in all_data:
+        assert os.path.exists(item)
+        assert item.endswith("txt")
 
 
 def test_yuque_reader():
     reader = SourceReaderABC.from_config(
         {
             "type": "yuque",
-            "token": "1yPz1LbE20FmXvemCDVwjlSHpAp18qtEu7wcjCfv",
-            "cut_depth": 1,
+            "token": "f6QiFu1gIDEGJIsI6jziOWbE7E9MsFkipeV69NHq",
         }
     )
-    from kag.builder.component import MarkDownReader
-
-    assert isinstance(reader.markdown_reader, MarkDownReader)
-
-    chunks = reader.invoke(
-        "https://yuque-api.antfin-inc.com/api/v2/repos/ob46m2/it70c2/docs/bnp80qitsy5vqoa5"
+    urls = reader.invoke(
+        "https://yuque-api.antfin-inc.com/api/v2/repos/un8gkl/kg7h1z/docs/"
     )
-    assert chunks[0].content[:6] == "1、建设目标"
+    for url in urls:
+        token, real_url = url.split("@", 1)
+        assert token == reader.token
diff --git a/tests/unit/builder/component/test_record_parser.py b/tests/unit/builder/component/test_record_parser.py
new file mode 100644
index 00000000..66196a84
--- /dev/null
+++ b/tests/unit/builder/component/test_record_parser.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+
+import os
+
+import copy
+from kag.interface import RecordParserABC
+from kag.builder.model.chunk import Chunk
+
+pwd = os.path.dirname(__file__)
+
+
+def test_dict_parser():
+    parser = RecordParserABC.from_config(
+        {
+            "type": "dict",
+            "id_col": "data_id",
+            "name_col": "data_name",
+            "content_col": "data_content",
+        }
+    )
+    content = {
+        "data_id": "111",
+        "data_name": "222",
+        "data_content": "hello.",
+        "extra": "Nice.",
+    }
+    chunks = parser.invoke(copy.deepcopy(content))
+    assert len(chunks) == 1
+    assert isinstance(chunks[0], Chunk)
+    chunk = chunks[0]
+    assert chunk.id == content["data_id"]
+    assert chunk.name == content["data_name"]
+    assert chunk.content == content["data_content"]
+    assert chunk.kwargs["extra"] == content["extra"]
+
+
+def test_text_parser():
+    parser = RecordParserABC.from_config({"type": "txt"})
+    text = "您好!"
+    chunks = parser.invoke(text)
+    assert len(chunks) == 1 and chunks[0].content == text
+
+    file_path = os.path.join(pwd, "../data/test_txt.txt")
+    chunks = parser.invoke(file_path)
+    with open(file_path) as f:
+        content = f.read()
+    chunks = parser.invoke(file_path)
+    assert len(chunks) == 1
+    assert chunks[0].content == content
+    assert chunks[0].id == Chunk.generate_hash_id(file_path)
+
+
+def test_docx_parser():
+    parser = RecordParserABC.from_config({"type": "docx"})
+
+    file_path = os.path.join(pwd, "../data/test_docx.docx")
+    chunks = parser.invoke(file_path)
+    # Assert the expected result
+    assert len(chunks) == 30
+    assert len(chunks[0].content) > 0
+
+
+def test_md_parser():
+    parser = RecordParserABC.from_config({"type": "md", "cut_depth": 1})
+    file_path = os.path.join(pwd, "../data/test_markdown.md")
+    chunks = parser.invoke(file_path)
+    assert len(chunks) > 0
+    assert chunks[0].name == "test_markdown#0"
+
+
+def test_pdf_parser():
+    parser = RecordParserABC.from_config({"type": "pdf"})
+
+    page = "Header\nContent 1\nContent 2\nFooter"
+    watermark = "Header"
+    expected = ["Content 1", "Content 2"]
+    result = parser._process_single_page(
+        page, watermark, remove_header=True, remove_footnote=True
+    )
+    assert result == expected
+    file_path = os.path.join(pwd, "../data/test_pdf.pdf")
+    chunks = parser.invoke(file_path)
+    assert chunks[0].name == "test_pdf#0"
+
+
+def test_yuque_parser():
+    parser = RecordParserABC.from_config({"type": "yuque", "cut_depth": 1})
+    chunks = parser.invoke(
+        "f6QiFu1gIDEGJIsI6jziOWbE7E9MsFkipeV69NHq@https://yuque-api.antfin-inc.com/api/v2/repos/un8gkl/kg7h1z/docs/odtmme"
+    )
+    assert chunks[0].name == "项目立项#0"
diff --git a/tests/unit/builder/test_runner.py b/tests/unit/builder/test_runner.py
new file mode 100644
index 00000000..5c19d262
--- /dev/null
+++ b/tests/unit/builder/test_runner.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+from kag.builder.runner import CKPT, BuilderChainRunner
+
+
+def test_ckpt():
+    ckpt = CKPT("./")
+    ckpt.open()
+    ckpt.add("aaaa")
+    ckpt.add("bbbb")
+    ckpt.add("cccc")
+    ckpt.close()
+
+    ckpt = CKPT("./")
+    assert ckpt.is_processed("aaaa")
+    assert ckpt.is_processed("bbbb")
+    assert ckpt.is_processed("cccc")
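
A quick sanity check of the sharding math introduced in kag/common/sharding_info.py
(an illustrative sketch of the new API, not part of the patch itself):

    from kag.common.sharding_info import ShardingInfo

    # Two workers split a ten-item workload: rank 0 gets items [0, 5),
    # rank 1 gets items [5, 10).
    for rank in range(2):
        info = ShardingInfo(shard_id=rank, shard_count=2)
        print(rank, info.get_sharding_range(10))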