diff --git a/kag/examples/FinState/solver/solver.py b/kag/examples/FinState/solver/solver.py index 4af0c15..73bb6fc 100644 --- a/kag/examples/FinState/solver/solver.py +++ b/kag/examples/FinState/solver/solver.py @@ -9,13 +9,15 @@ class FinStateSolver(SolverPipeline): """ def __init__( - self, max_run=3, reflector=None, reasoner=None, generator=None, **kwargs + self, max_run=3, reflector=None, reasoner=None, generator=None, llm_client = None, **kwargs ): super().__init__(max_run, reflector, reasoner, generator, **kwargs) - from kag.common.conf import KAG_CONFIG - KAG_CONFIG.all_config["chat_llm"] - llm: LLMClient = LLMClient.from_config(KAG_CONFIG.all_config["chat_llm"]) + llm = llm_client + if not llm: + from kag.common.conf import KAG_CONFIG + llm: LLMClient = LLMClient.from_config(KAG_CONFIG.all_config["chat_llm"]) + self.table_reasoner = TableReasoner(llm_module = llm, **kwargs) def run(self, question, **kwargs): diff --git a/kag/solver/main_solver.py b/kag/solver/main_solver.py index 1252085..75a3a2c 100644 --- a/kag/solver/main_solver.py +++ b/kag/solver/main_solver.py @@ -12,7 +12,7 @@ import copy from kag.examples.FinState.solver.solver import FinStateSolver -from kag.solver.logic.solver_pipeline import SolverPipeline +from kag.interface import LLMClient from kag.solver.tools.info_processor import ReporterIntermediateProcessTool from kag.common.conf import KAG_CONFIG, KAG_PROJECT_CONF @@ -36,8 +36,10 @@ def invoke( host_addr=host_addr, language=KAG_PROJECT_CONF.language, ) + + llm_client: LLMClient = LLMClient.from_config(KAG_CONFIG.all_config["llm"]) solver = FinStateSolver( - report_tool=report_tool, KAG_PROJECT_ID=project_id + report_tool=report_tool, KAG_PROJECT_ID=project_id, llm_client=llm_client ) answer = solver.run(query, report_tool=report_tool, session_id=session_id) return answer