diff --git a/kag/common/registry/registrable.py b/kag/common/registry/registrable.py
index e6ce5d3f..7e0bbacd 100644
--- a/kag/common/registry/registrable.py
+++ b/kag/common/registry/registrable.py
@@ -812,9 +812,7 @@ def from_config(
             logger.warn(f"Failed to initialize class {cls}, info: {e}")
             raise e
         if len(params) > 0:
-            raise ConfigurationError(
-                f"These params are not used for constructing {cls}:\n{params}"
-            )
+            logger.warn(f"These params are not used for constructing {cls}:\n{params}")
         return instant
diff --git a/kag/examples/2wiki/kag_config.yaml b/kag/examples/2wiki/kag_config.yaml
index f309e3f8..130d01eb 100644
--- a/kag/examples/2wiki/kag_config.yaml
+++ b/kag/examples/2wiki/kag_config.yaml
@@ -61,21 +61,28 @@
 vectorize_model: *id002
 vectorizer: *id002
 lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
   generator:
+    type: default_generator
     generate_prompt:
       type: resp_simple
+  reflector:
+    type: default_reflector
   reasoner:
-    type: base
+    type: default_reasoner
+    lf_planner:
+      type: default_lf_planner
     lf_executor:
-      type: base
+      type: default_lf_executor
       force_chunk_retriever: true
       exact_kg_retriever:
         type: default_exact_kg_retriever
         el_num: 5
         search_api: &id003
-          type: openspg
+          type: openspg_search_api
         graph_api: &id004
-          type: openspg
+          type: openspg_graph_api
       fuzzy_kg_retriever:
         type: default_fuzzy_kg_retriever
         el_num: 5
@@ -88,5 +95,5 @@ lf_solver_pipeline:
         recall_num: 10
         rerank_topk: 10
       merger:
-        type: base
+        type: default_lf_sub_query_res_merger
         chunk_retriever: *id005
\ No newline at end of file
diff --git a/kag/examples/baike/kag_config.yaml b/kag/examples/baike/kag_config.yaml
index be0cf65e..539b0e4b 100644
--- a/kag/examples/baike/kag_config.yaml
+++ b/kag/examples/baike/kag_config.yaml
@@ -61,34 +61,42 @@ kag-indexer:
   scanner:
     type: dir
+
+
 lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
   generator:
+    type: default_generator
     generate_prompt:
-      type: default_resp_generator # kag.solver.prompt.default.resp_generator.RespGenerator
-      #type: resp_simple # solver.prompt.resp_generator.RespGenerator
+      type: default_resp_generator
+  reflector:
+    type: default_reflector
   reasoner:
-    type: base
+    type: default_reasoner
+    lf_planner:
+      type: default_lf_planner
     lf_executor:
-      type: base
+      type: default_lf_executor
       force_chunk_retriever: true
       exact_kg_retriever:
-        type: default
+        type: default_exact_kg_retriever
         el_num: 5
         search_api: &id003
-          type: openspg
+          type: openspg_search_api
         graph_api: &id004
-          type: openspg
+          type: openspg_graph_api
       fuzzy_kg_retriever:
-        type: default
+        type: default_fuzzy_kg_retriever
         el_num: 5
         vectorize_model: *id002
         llm_client: *id001
         search_api: *id003
         graph_api: *id004
       chunk_retriever: &id005
-        type: default
+        type: default_chunk_retriever
         recall_num: 10
         rerank_topk: 10
       merger:
-        type: base
-        chunk_retriever: *id005
+        type: default_lf_sub_query_res_merger
+        chunk_retriever: *id005
\ No newline at end of file
diff --git a/kag/examples/hotpotqa/kag_config.yaml b/kag/examples/hotpotqa/kag_config.yaml
index 3d764178..eee5f17b 100644
--- a/kag/examples/hotpotqa/kag_config.yaml
+++ b/kag/examples/hotpotqa/kag_config.yaml
@@ -59,22 +59,30 @@ runner:
   type: hotpotqa
 vectorize_model: *id002
 vectorizer: *id002
+
 lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
   generator:
+    type: default_generator
     generate_prompt:
       type: resp_simple
+  reflector:
+    type: default_reflector
   reasoner:
-    type: base
+    type: default_reasoner
+    lf_planner:
+      type: default_lf_planner
     lf_executor:
-      type: base
+      type: default_lf_executor
       force_chunk_retriever: true
       exact_kg_retriever:
         type: default_exact_kg_retriever
         el_num: 5
         search_api: &id003
-          type: openspg
+          type: openspg_search_api
         graph_api: &id004
-          type: openspg
+          type: openspg_graph_api
       fuzzy_kg_retriever:
         type: default_fuzzy_kg_retriever
         el_num: 5
@@ -87,5 +95,5 @@ lf_solver_pipeline:
         recall_num: 10
         rerank_topk: 10
       merger:
-        type: base
+        type: default_lf_sub_query_res_merger
         chunk_retriever: *id005
\ No newline at end of file
diff --git a/kag/examples/medicine/kag_config.yaml b/kag/examples/medicine/kag_config.yaml
index a4f8a38c..a2fe7f00 100644
--- a/kag/examples/medicine/kag_config.yaml
+++ b/kag/examples/medicine/kag_config.yaml
@@ -36,18 +36,7 @@ extract_runner:
   num_chains: 8
   scanner:
     type: csv
-lf_solver_pipeline:
-  generator:
-    generate_prompt:
-      type: example_resp_generator
-  reasoner:
-    lf_solver:
-      chunk_retriever:
-        ner_prompt:
-          type: example_medical_question_ner
-        type: kag
-      kg_retriever:
-        type: base
+
 llm: *id001
 log:
   level: INFO
@@ -85,3 +74,41 @@ vectorize_model: *id002
 vectorizer:
   type: batch
   vectorize_model: *id002
+
+lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
+  generator:
+    type: default_generator
+    generate_prompt:
+      type: example_resp_generator
+  reflector:
+    type: default_reflector
+  reasoner:
+    type: default_reasoner
+    lf_planner:
+      type: default_lf_planner
+    lf_executor:
+      type: default_lf_executor
+      force_chunk_retriever: true
+      exact_kg_retriever:
+        type: default_exact_kg_retriever
+        el_num: 5
+        search_api: &id003
+          type: openspg_search_api
+        graph_api: &id004
+          type: openspg_graph_api
+      fuzzy_kg_retriever:
+        type: default_fuzzy_kg_retriever
+        el_num: 5
+        vectorize_model: *id002
+        llm_client: *id001
+        search_api: *id003
+        graph_api: *id004
+      chunk_retriever: &id005
+        type: default_chunk_retriever
+        recall_num: 10
+        rerank_topk: 10
+      merger:
+        type: default_lf_sub_query_res_merger
+        chunk_retriever: *id005
diff --git a/kag/examples/musique/kag_config.yaml b/kag/examples/musique/kag_config.yaml
index 61600d1f..13f60580 100644
--- a/kag/examples/musique/kag_config.yaml
+++ b/kag/examples/musique/kag_config.yaml
@@ -61,21 +61,28 @@
 vectorize_model: *id002
 vectorizer: *id002
 lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
   generator:
+    type: default_generator
     generate_prompt:
       type: resp_simple
+  reflector:
+    type: default_reflector
   reasoner:
-    type: base
+    type: default_reasoner
+    lf_planner:
+      type: default_lf_planner
     lf_executor:
-      type: base
+      type: default_lf_executor
       force_chunk_retriever: true
       exact_kg_retriever:
         type: default_exact_kg_retriever
         el_num: 5
         search_api: &id003
-          type: openspg
+          type: openspg_search_api
         graph_api: &id004
-          type: openspg
+          type: openspg_graph_api
       fuzzy_kg_retriever:
         type: default_fuzzy_kg_retriever
         el_num: 5
@@ -88,5 +95,5 @@ lf_solver_pipeline:
         recall_num: 10
         rerank_topk: 10
      merger:
-        type: base
-        chunk_retriever: *id005
+        type: default_lf_sub_query_res_merger
+        chunk_retriever: *id005
\ No newline at end of file
diff --git a/kag/examples/riskmining/kag_config.yaml b/kag/examples/riskmining/kag_config.yaml
index 1a00d49d..213eb8de 100644
--- a/kag/examples/riskmining/kag_config.yaml
+++ b/kag/examples/riskmining/kag_config.yaml
@@ -1,13 +1,4 @@
-lf_solver_pipeline:
-  reasoner:
-    lf_planner:
-      logic_form_plan_prompt:
-        type: riskmining_lf_plan
-      type: base
-  generator:
-    generate_prompt:
-      type: resp_riskmining
-llm:
+llm: &id001
   api_key: key
   base_url: https://api.deepseek.com
   model: deepseek-chat
@@ -20,9 +11,50 @@ project:
   id: '4'
   language: zh
   namespace: RiskMining
-vectorize_model: &id001
+vectorize_model: &id002
   type: mock
   vector_dimensions: 768
 vectorizer:
   type: batch
-  vectorize_model: *id001
+  vectorize_model: *id002
+
+
+lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
+  generator:
+    type: default_generator
+    generate_prompt:
+      type: resp_riskmining
+  reflector:
+    type: default_reflector
+  reasoner:
+    type: default_reasoner
+    lf_planner:
+      logic_form_plan_prompt:
+        type: riskmining_lf_plan
+      type: default_lf_planner
+    lf_executor:
+      type: default_lf_executor
+      force_chunk_retriever: true
+      exact_kg_retriever:
+        type: default_exact_kg_retriever
+        el_num: 1
+        search_api: &id003
+          type: openspg_search_api
+        graph_api: &id004
+          type: openspg_graph_api
+      fuzzy_kg_retriever:
+        type: default_fuzzy_kg_retriever
+        el_num: 1
+        vectorize_model: *id002
+        llm_client: *id001
+        search_api: *id003
+        graph_api: *id004
+      chunk_retriever: &id005
+        type: default_chunk_retriever
+        recall_num: 10
+        rerank_topk: 10
+      merger:
+        type: default_lf_sub_query_res_merger
+        chunk_retriever: *id005
\ No newline at end of file
diff --git a/kag/examples/supplychain/kag_config.yaml b/kag/examples/supplychain/kag_config.yaml
index b7f79692..9adc5e98 100644
--- a/kag/examples/supplychain/kag_config.yaml
+++ b/kag/examples/supplychain/kag_config.yaml
@@ -1,13 +1,4 @@
-lf_solver_pipeline:
-  reasoner:
-    lf_planner:
-      logic_form_plan_prompt:
-        type: supplychain_lf_plan
-      type: base
-  generator:
-    generate_prompt:
-      type: resp_supplychain
-llm:
+llm: &id001
   api_key: key
   base_url: https://api.deepseek.com
   model: deepseek-chat
@@ -20,9 +11,50 @@ project:
   id: '5'
   language: zh
   namespace: SupplyChain
-vectorize_model: &id001
+vectorize_model: &id002
   type: mock
   vector_dimensions: 768
 vectorizer:
   type: batch
-  vectorize_model: *id001
+  vectorize_model: *id002
+
+
+lf_solver_pipeline:
+  memory: default_memory
+  max_iterations: 3
+  generator:
+    type: default_generator
+    generate_prompt:
+      type: resp_supplychain
+  reflector:
+    type: default_reflector
+  reasoner:
+    type: default_reasoner
+    lf_planner:
+      logic_form_plan_prompt:
+        type: supplychain_lf_plan
+      type: default_lf_planner
+    lf_executor:
+      type: default_lf_executor
+      force_chunk_retriever: true
+      exact_kg_retriever:
+        type: default_exact_kg_retriever
+        el_num: 1
+        search_api: &id003
+          type: openspg_search_api
+        graph_api: &id004
+          type: openspg_graph_api
+      fuzzy_kg_retriever:
+        type: default_fuzzy_kg_retriever
+        el_num: 1
+        vectorize_model: *id002
+        llm_client: *id001
+        search_api: *id003
+        graph_api: *id004
+      chunk_retriever: &id005
+        type: default_chunk_retriever
+        recall_num: 10
+        rerank_topk: 10
+      merger:
+        type: default_lf_sub_query_res_merger
+        chunk_retriever: *id005
\ No newline at end of file
diff --git a/kag/interface/solver/base_model.py b/kag/interface/solver/base_model.py
index a3d9d89a..b2325091 100644
--- a/kag/interface/solver/base_model.py
+++ b/kag/interface/solver/base_model.py
@@ -77,7 +77,7 @@ def get_type_with_gql_format(self):
     def get_entity_first_std_type(self):
         type_list = list(self.get_entity_type_set())
         if len(type_list) == 0:
-            return None
+            return "Entity"
         return type_list[0]

     def get_un_std_entity_first_type_or_std(self):
@@ -88,7 +88,7 @@ def get_un_std_entity_first_type_or_std(self):
         elif len(std_type) > 0:
             return std_type[0]
         else:
-            return None
+            return "Entity"

     def get_entity_type_or_un_std_list(self):
         ret = []
@@ -107,7 +107,7 @@ def get_entity_first_type_or_un_std(self):
         elif len(unstd_type) > 0:
             return unstd_type[0]
         else:
-            return None
+            return "Entity"

     def get_entity_type_set(self):
         entity_types = []
@@ -121,7 +121,10 @@ def get_un_std_entity_type_set(self):
         for entity_type_info in self.type_set:
             if entity_type_info.un_std_entity_type is not None:
                 entity_types.append(entity_type_info.un_std_entity_type)
-        return set(entity_types)
+        entity_types = set(entity_types)
+        if len(entity_types) == 0:
+            return ["Entity"]
+        return entity_types


 class SPORelation(SPOBase):
diff --git a/kag/solver/execute/default_lf_executor.py b/kag/solver/execute/default_lf_executor.py
index 7542f24d..11277abf 100644
--- a/kag/solver/execute/default_lf_executor.py
+++ b/kag/solver/execute/default_lf_executor.py
@@ -24,7 +24,7 @@
 logger = logging.getLogger()


-@LFExecutorABC.register("base", as_default=True)
+@LFExecutorABC.register("default_lf_executor", as_default=True)
 class DefaultLFExecutor(LFExecutorABC):
     def __init__(self, exact_kg_retriever: ExactKgRetriever, fuzzy_kg_retriever: FuzzyKgRetriever,
                  chunk_retriever: ChunkRetriever, merger: LFSubQueryResMerger, force_chunk_retriever: bool = False,
@@ -48,6 +48,7 @@ def __init__(self, exact_kg_retriever: ExactKgRetriever, fuzzy_kg_retriever: Fuz
         self.params['fuzzy_kg_retriever'] = fuzzy_kg_retriever
         self.params['chunk_retriever'] = chunk_retriever
         self.params['force_chunk_retriever'] = force_chunk_retriever
+        self.params['llm_module'] = llm_client

         # Generate
         self.generator = LFSubGenerator(llm_client=llm_client)
@@ -106,7 +107,10 @@ def _execute_spo_answer(self, req_id: str, query: str, lf: LFPlan, process_info:

     def _execute_chunk_answer(self, req_id: str, query: str, lf: LFPlan, process_info: Dict, kg_graph: KgGraph,
                               history: List[LFPlan], res: SubQueryResult) -> SubQueryResult:
-        if not self._judge_sub_answered(res.sub_answer):
+        if not self._judge_sub_answered(res.sub_answer) or self.force_chunk_retriever:
+            if self.force_chunk_retriever:
+                # force chunk retriever, so we clear kg solved answer
+                process_info['kg_solved_answer'] = []
             # chunk retriever
             all_related_entities = kg_graph.get_all_entity()
             all_related_entities = list(set(all_related_entities))
diff --git a/kag/solver/execute/default_sub_query_merger.py b/kag/solver/execute/default_sub_query_merger.py
index b07568d8..225300a0 100644
--- a/kag/solver/execute/default_sub_query_merger.py
+++ b/kag/solver/execute/default_sub_query_merger.py
@@ -9,15 +9,15 @@
 from kag.solver.retriever.chunk_retriever import ChunkRetriever


-@LFSubQueryResMerger.register("base", as_default=True)
+@LFSubQueryResMerger.register("default_lf_sub_query_res_merger", as_default=True)
 class DefaultLFSubQueryResMerger(LFSubQueryResMerger):
     """
     Initializes the base planner.
""" - def __init__(self, chunk_retriever: ChunkRetriever = None, vectorize_model: VectorizeModelABC = None, **kwargs): + def __init__(self, chunk_retriever: ChunkRetriever, vectorize_model: VectorizeModelABC = None, **kwargs): super().__init__(**kwargs) - self.chunk_retriever = chunk_retriever or ChunkRetriever.from_config({"type": "default"}) + self.chunk_retriever = chunk_retriever self.vectorize_model = vectorize_model or VectorizeModelABC.from_config( KAG_CONFIG.all_config["vectorize_model"]) self.text_similarity = TextSimilarity(vectorize_model) diff --git a/kag/solver/execute/op_executor/op_output/module/get_executor.py b/kag/solver/execute/op_executor/op_output/module/get_executor.py index 950122e2..81b8c692 100644 --- a/kag/solver/execute/op_executor/op_output/module/get_executor.py +++ b/kag/solver/execute/op_executor/op_output/module/get_executor.py @@ -39,5 +39,5 @@ def executor(self, nl_query: str, logic_node: LogicNode, req_id: str, kg_graph: if isinstance(s_data, RelationData): kg_qa_result.append(str(s_data)) process_info[logic_node.sub_query]['kg_answer'] += f"\n{';'.join(kg_qa_result)}" - process_info['kg_solved_answer'] += f"\n{';'.join(kg_qa_result)}" + process_info['kg_solved_answer'].append(f"\n{';'.join(kg_qa_result)}") return process_info[logic_node.sub_query] diff --git a/kag/solver/implementation/default_generator.py b/kag/solver/implementation/default_generator.py index 96f77346..48cd7970 100644 --- a/kag/solver/implementation/default_generator.py +++ b/kag/solver/implementation/default_generator.py @@ -8,7 +8,7 @@ from kag.solver.implementation.default_memory import DefaultMemory -@KAGGeneratorABC.register("base", as_default=True) +@KAGGeneratorABC.register("default_generator", as_default=True) class DefaultGenerator(KAGGeneratorABC): """ The Generator class is an abstract base class for generating responses using a language model module. diff --git a/kag/solver/implementation/default_memory.py b/kag/solver/implementation/default_memory.py index 3465601e..237dd0a1 100644 --- a/kag/solver/implementation/default_memory.py +++ b/kag/solver/implementation/default_memory.py @@ -9,7 +9,7 @@ logger = logging.getLogger() -@KagMemoryABC.register("base", as_default=True) +@KagMemoryABC.register("default_memory", as_default=True) class DefaultMemory(KagMemoryABC): def __init__( self, diff --git a/kag/solver/implementation/default_reasoner.py b/kag/solver/implementation/default_reasoner.py index 422161e8..1148df08 100644 --- a/kag/solver/implementation/default_reasoner.py +++ b/kag/solver/implementation/default_reasoner.py @@ -10,7 +10,7 @@ logger = logging.getLogger() -@KagReasonerABC.register("base", as_default=True) +@KagReasonerABC.register("default_reasoner", as_default=True) class DefaultReasoner(KagReasonerABC): """ A processor class for handling logical form tasks in language processing. 
@@ -31,16 +31,16 @@ class DefaultReasoner(KagReasonerABC):

     def __init__(
         self,
-        lf_planner: LFPlannerABC = None,
-        lf_executor: LFExecutorABC = None,
+        lf_planner: LFPlannerABC,
+        lf_executor: LFExecutorABC,
         llm_client: LLMClient = None,
         **kwargs,
     ):
         super().__init__(llm_client, **kwargs)
-        self.lf_planner = lf_planner or LFPlannerABC.from_config({"type": "base"})
+        self.lf_planner = lf_planner

-        self.lf_executor = lf_executor or LFExecutorABC.from_config({"type": "base"})
+        self.lf_executor = lf_executor

         self.sub_query_total = 0
         self.kg_direct = 0
         self.trace_log = []
diff --git a/kag/solver/implementation/default_reflector.py b/kag/solver/implementation/default_reflector.py
index 0246ce4a..ea45f701 100644
--- a/kag/solver/implementation/default_reflector.py
+++ b/kag/solver/implementation/default_reflector.py
@@ -7,7 +7,7 @@
 from kag.solver.utils import init_prompt_with_fallback


-@KagReflectorABC.register("base", as_default=True)
+@KagReflectorABC.register("default_reflector", as_default=True)
 class DefaultReflector(KagReflectorABC):
     def __init__(
         self,
@@ -81,4 +81,4 @@ def _refine_query(self, memory: KagMemoryABC, instruction: str):
         )
         if len(update_reason_path) == 0:
             return None
-        return update_reason_path[0]
+        return "\n".join(update_reason_path)
diff --git a/kag/solver/logic/solver_pipeline.py b/kag/solver/logic/solver_pipeline.py
index 441ae48e..1c2afb4f 100644
--- a/kag/solver/logic/solver_pipeline.py
+++ b/kag/solver/logic/solver_pipeline.py
@@ -14,22 +14,25 @@


 class SolverPipeline(Registrable):
-    def __init__(self, max_run=3, reflector: KagReflectorABC = None, reasoner: KagReasonerABC = None,
-                 generator: KAGGeneratorABC = None, **kwargs):
+    def __init__(self, reflector: KagReflectorABC, reasoner: KagReasonerABC,
+                 generator: KAGGeneratorABC, memory: str = "default_memory", max_iterations=3, **kwargs):
         """
         Initializes the think-and-act loop class.

-        :param max_run: Maximum number of runs to limit the thinking and acting loop, defaults to 3.
+        :param max_iterations: Maximum number of iterations to limit the thinking and acting loop, defaults to 3.
         :param reflector: Reflector instance for reflect tasks.
         :param reasoner: Reasoner instance for reasoning about tasks.
         :param generator: Generator instance for generating actions.
+        :param memory: Assign memory store type
         """
         super().__init__(**kwargs)

-        self.max_run = max_run
+        self.max_iterations = max_iterations

-        self.reflector = reflector or KagReflectorABC.from_config({"type": "base"})
-        self.reasoner = reasoner or KagReasonerABC.from_config({"type": "base"})
-        self.generator = generator or KAGGeneratorABC.from_config({"type": "base"})
+        self.reflector = reflector
+        self.reasoner = reasoner
+        self.generator = generator
+
+        self.memory_type = memory

         self.param = kwargs
@@ -49,9 +52,9 @@ def run(self, question, **kwargs):
         trace_log = []
         present_instruction = instruction
         run_cnt = 0
-        memory = KagMemoryABC.from_config({"type": "base"})
+        memory = KagMemoryABC.from_config({"type": self.memory_type})

-        while not if_finished and run_cnt < self.max_run:
+        while not if_finished and run_cnt < self.max_iterations:
             run_cnt += 1
             logger.debug("present_instruction is:{}".format(present_instruction))
             # Attempt to solve the current instruction and get the answer, supporting facts, and history log
diff --git a/kag/solver/main_solver.py b/kag/solver/main_solver.py
index b7ac9050..cb6f5517 100644
--- a/kag/solver/main_solver.py
+++ b/kag/solver/main_solver.py
@@ -14,7 +14,7 @@
 from kag.solver.logic.solver_pipeline import SolverPipeline
 from kag.solver.tools.info_processor import ReporterIntermediateProcessTool

-from kag.common.conf import KAG_CONFIG,KAG_PROJECT_CONF
+from kag.common.conf import KAG_CONFIG, KAG_PROJECT_CONF


 class SolverMain:
@@ -35,15 +35,15 @@ def invoke(
             language=KAG_PROJECT_CONF.language
         )
         default_pipeline_config = {
+            'max_iterations': 3,
+            'memory': 'default_memory',
             'generator': {
                 'generate_prompt': {
                     'type': 'default_resp_generator'
-                }
+                },
+                'type': 'default_generator'
             },
             'reasoner': {
-                'lf_planner': {
-                    'type': 'base'
-                },
                 'lf_executor': {
                     'chunk_retriever': {
                         'recall_num': 10,
@@ -53,10 +53,10 @@ def invoke(
                     'exact_kg_retriever': {
                         'el_num': 5,
                         'graph_api': {
-                            'type': 'openspg'
+                            'type': 'openspg_graph_api'
                         },
                         'search_api': {
-                            'type': 'openspg'
+                            'type': 'openspg_search_api'
                         },
                         'type': 'default_exact_kg_retriever'
                     },
@@ -64,10 +64,10 @@ def invoke(
                     'fuzzy_kg_retriever': {
                         'el_num': 5,
                         'graph_api': {
-                            'type': 'openspg'
+                            'type': 'openspg_graph_api'
                         },
                         'search_api': {
-                            'type': 'openspg'
+                            'type': 'openspg_search_api'
                         },
                         'type': 'default_fuzzy_kg_retriever',
                     },
@@ -75,13 +75,19 @@ def invoke(
                         'chunk_retriever': {
                             'recall_num': 10,
                             'rerank_topk': 10,
-                            'type': 'default_chunk_retriever',
+                            'type': 'default_chunk_retriever'
                         },
-                        'type': 'base'
+                        'type': 'default_lf_sub_query_res_merger'
                     },
-                    'type': 'base'
+                    'type': 'default_lf_executor'
                 },
-                'type': 'base'
+                'lf_planner': {
+                    'type': 'default_lf_planner'
+                },
+                'type': 'default_reasoner'
+            },
+            'reflector': {
+                'type': 'default_reflector'
             }
         }
         conf = copy.deepcopy(KAG_CONFIG.all_config.get("lf_solver_pipeline", default_pipeline_config))
diff --git a/kag/solver/plan/default_lf_planner.py b/kag/solver/plan/default_lf_planner.py
index 7059386b..009e5e76 100644
--- a/kag/solver/plan/default_lf_planner.py
+++ b/kag/solver/plan/default_lf_planner.py
@@ -17,7 +17,7 @@
 logger = logging.getLogger()


-@LFPlannerABC.register("base", as_default=True)
+@LFPlannerABC.register("default_lf_planner", as_default=True)
 class DefaultLFPlanner(LFPlannerABC):
     """
     Planner class that extends the base planner functionality to generate sub-queries and logic forms.
@@ -75,9 +75,18 @@ def _split_sub_query(self, logic_nodes: List[LogicNode]) -> List[LFPlan]:
             plan_result.append(LFPlan(query=k, lf_nodes=v))
         return plan_result

+    def _process_output_query(self, question, sub_query: str):
+        if sub_query is None:
+            return question
+        if 'output' == sub_query.lower():
+            return f"output `{question}` answer:"
+        return sub_query
+
     def _parse_lf(self, question, sub_querys, logic_forms) -> List[LFPlan]:
         if sub_querys is None:
             sub_querys = []
+        # process sub query
+        sub_querys = [self._process_output_query(question, q) for q in sub_querys]
         parsed_logic_nodes = self.parser.parse_logic_form_set(
             logic_forms, sub_querys, question
         )
diff --git a/kag/solver/prompt/default/spo_retrieval.py b/kag/solver/prompt/default/spo_retrieval.py
index 3338e878..c32cd929 100644
--- a/kag/solver/prompt/default/spo_retrieval.py
+++ b/kag/solver/prompt/default/spo_retrieval.py
@@ -47,12 +47,12 @@ class SpoRetrieval(PromptABC):
            ]
        }
    ],
-    "模板": {
+    "任务": {
        "问题": "$question",
        "SPO 提及": "$mention",
-        "SPO 候选项": "$candis",
-        "output": []
-    }
+        "SPO 候选项": "$candis"
+    },
+    "output": "提供一个JSON列表,其中包含根据SPO提及内容选出的最佳回答问题的SPO候选者。"
 }
 """
     template_en = """{
@@ -107,11 +107,13 @@ def template_variables(self) -> List[str]:
         return ["question", "mention", "candis"]

-    def parse_response(self, response: Dict, **kwargs):
+    def parse_response(self, response, **kwargs):
         logger.debug(
             f"SpoRetrieval {response} mention:{self.template_variables_value.get('mention', '')} "
             f"candis:{self.template_variables_value.get('candis', '')}"
         )
+        if not isinstance(response, dict):
+            return []
         if "output" in response:
             return response["output"]
         if "Output" in response:
diff --git a/kag/solver/retriever/base/kg_retriever.py b/kag/solver/retriever/base/kg_retriever.py
index 32f46f47..b6e92f23 100644
--- a/kag/solver/retriever/base/kg_retriever.py
+++ b/kag/solver/retriever/base/kg_retriever.py
@@ -25,11 +25,11 @@ def __init__(self, el_num=5, llm_client: LLMClient = None, vectorize_model: Vect
             "KAG_PROJECT_HOST_ADDR": KAG_PROJECT_CONF.host_addr
         }))
         self.graph_api = graph_api or GraphApiABC.from_config({
-            "type": "openspg"}
+            "type": "openspg_graph_api"}
         )

         self.search_api = search_api or SearchApiABC.from_config({
-            "type": "openspg"
+            "type": "openspg_search_api"
         })

         self.vectorize_model = vectorize_model or VectorizeModelABC.from_config(
diff --git a/kag/solver/retriever/chunk_retriever.py b/kag/solver/retriever/chunk_retriever.py
index 98e2fb62..918afdf0 100644
--- a/kag/solver/retriever/chunk_retriever.py
+++ b/kag/solver/retriever/chunk_retriever.py
@@ -24,11 +24,11 @@ def __init__(self, recall_num: int = 10,
             "KAG_PROJECT_HOST_ADDR": KAG_PROJECT_CONF.host_addr
         }))
         self.graph_api = graph_api or GraphApiABC.from_config({
-            "type": "openspg"}
+            "type": "openspg_graph_api"}
         )

         self.search_api = search_api or SearchApiABC.from_config({
-            "type": "openspg"
+            "type": "openspg_search_api"
         })
diff --git a/kag/solver/tools/graph_api/impl/openspg_graph_api.py b/kag/solver/tools/graph_api/impl/openspg_graph_api.py
index 807831ac..7efaed40 100644
--- a/kag/solver/tools/graph_api/impl/openspg_graph_api.py
+++ b/kag/solver/tools/graph_api/impl/openspg_graph_api.py
@@ -49,7 +49,7 @@ def convert_node_to_json(node_str):
     }


-@GraphApiABC.register("openspg", as_default=True)
+@GraphApiABC.register("openspg_graph_api", as_default=True)
 class OpenSPGGraphApi(GraphApiABC):
     def __init__(self, project_id=None, host_addr=None, **kwargs):
         super().__init__(**kwargs)
diff --git a/kag/solver/tools/search_api/impl/openspg_search_api.py b/kag/solver/tools/search_api/impl/openspg_search_api.py
index 56b4eb93..26610de2 100644
--- a/kag/solver/tools/search_api/impl/openspg_search_api.py
+++ b/kag/solver/tools/search_api/impl/openspg_search_api.py
@@ -5,7 +5,7 @@
 from knext.search.client import SearchClient


-@SearchApiABC.register("openspg", as_default=True)
+@SearchApiABC.register("openspg_search_api", as_default=True)
 class OpenSPGSearchAPI(SearchApiABC):
     def __init__(self, project_id=None, host_addr=None, **kwargs):
         super().__init__(**kwargs)
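The recurring pattern in this patch is a registry-key rename: implementations previously registered under generic keys such as "base" and "openspg" are re-registered under descriptive keys ("default_reasoner", "default_lf_executor", "openspg_graph_api", ...), every YAML component block now carries an explicit type:, and from_config in registrable.py now warns about unused parameters instead of raising. The sketch below is a minimal, self-contained illustration of that registry pattern only; it is not KAG's actual Registrable implementation, and the names SimpleRegistrable, Reasoner, and unused_flag are invented for illustration.

```python
import inspect
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class SimpleRegistrable:
    """Keeps a per-base-class mapping from a string `type` name to an implementation class."""

    _registry = {}

    @classmethod
    def register(cls, name, as_default=False):
        def decorator(subclass):
            bucket = SimpleRegistrable._registry.setdefault(cls, {})
            bucket[name] = subclass
            if as_default:
                bucket["__default__"] = subclass
            return subclass
        return decorator

    @classmethod
    def from_config(cls, config):
        params = dict(config)
        name = params.pop("type", "__default__")
        subclass = SimpleRegistrable._registry.get(cls, {}).get(name)
        if subclass is None:
            raise KeyError(f"No implementation registered as '{name}' for {cls.__name__}")
        accepted_names = inspect.signature(subclass.__init__).parameters
        accepted = {k: v for k, v in params.items() if k in accepted_names}
        leftover = {k: v for k, v in params.items() if k not in accepted_names}
        if leftover:
            # Mirrors the patched registrable.py behavior: warn on unused params instead of raising.
            logger.warning("These params are not used for constructing %s:\n%s", subclass, leftover)
        return subclass(**accepted)


class Reasoner(SimpleRegistrable):
    """Stand-in for an abstract component base such as KagReasonerABC."""


@Reasoner.register("default_reasoner", as_default=True)
class DefaultReasoner(Reasoner):
    def __init__(self, max_iterations: int = 3):
        self.max_iterations = max_iterations


# A YAML block like `reasoner: {type: default_reasoner, max_iterations: 3}` resolves to:
reasoner = Reasoner.from_config({"type": "default_reasoner", "max_iterations": 3, "unused_flag": True})
print(type(reasoner).__name__, reasoner.max_iterations)  # DefaultReasoner 3 (plus a warning for unused_flag)
```

Because each abstract base keeps its own name-to-class table, several unrelated components could previously share the key "base" without colliding; the descriptive keys trade that brevity for configs and error messages that state exactly which implementation is in play.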