diff --git a/README.md b/README.md index bd164b5..bb5f3a8 100644 --- a/README.md +++ b/README.md @@ -45,8 +45,8 @@ While ChatGPT did not make mistakes with the basic arithmetic operations, it cou The reason the `gpt-4o-mini` model is able to count the number of 'r's correctly is because DQA lets it use a function to calculate the occurrences of a specific character or a sequence of characters in a string. ### The agent workflow -The approximate current workflow for DQA can be summarised as follows. -![Workflow](./diagrams/workflow.svg) +The approximate _Structured Sub-Question ReAct_ (SSQReAct) workflow can be summarised as follows. +![SSQReAct workflow](./diagrams/ssqreact.svg) The DQA workflow uses a [self-discover](https://arxiv.org/abs/2402.03620) "agent" to produce a reasoning structure but not answer the question. Similar to the tutorial [^1], the DQA workflow performs query decomposition with respect to the reasoning structure to ensure that complex queries are not directly sent to the LLM. Instead, sub-questions (i.e., decompositions of the complex query) that help answer the complex query are sent. The workflow further optimises the sub-questions through a query refinement step, which loops if necessary, for a maximum number of allowed iterations. diff --git a/diagrams/workflow.svg b/diagrams/ssqreact.svg similarity index 100% rename from diagrams/workflow.svg rename to diagrams/ssqreact.svg diff --git a/src/dqa.py b/src/dqa.py new file mode 100644 index 0000000..9a8b574 --- /dev/null +++ b/src/dqa.py @@ -0,0 +1,364 @@ +# Copyright 2024 Anirban Basu + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Difficult Questions Attempted module containing various workflows.""" + +try: + from icecream import ic +except ImportError: # Graceful fallback if IceCream isn't installed. + ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa + +import sys +from tqdm import tqdm +import asyncio + +# Weaker LLMs may generate horrible JSON strings. +# `dirtyjson` is more lenient than `json` in parsing JSON strings. +from typing import List +from llama_index.tools.arxiv import ArxivToolSpec + +from llama_index.tools.wikipedia import WikipediaToolSpec +from llama_index.tools.tavily_research import TavilyToolSpec + +from llama_index.tools.yahoo_finance import YahooFinanceToolSpec +from llama_index.core.tools import FunctionTool + + +from tools import ( + DuckDuckGoFullSearchOnlyToolSpec, + StringFunctionsToolSpec, + BasicArithmeticCalculatorSpec, + MathematicalFunctionsSpec, +) +from utils import ( + APP_TITLE_SHORT, + FAKE_STRING, + ToolNames, + get_terminal_size, +) # , parse_env, EnvironmentVariables + + +from llama_index.core.llms.llm import LLM + +from workflows.ssq_react import StructuredSubQuestionReActWorkflow + + +class DQAEngine: + """The Difficult Questions Attempted engine.""" + + def __init__(self, llm: LLM | None = None): + """ + Initialize the Difficult Questions Attempted engine. + + Args: + llm (LLM): The function calling LLM instance to use. 
+ """ + self.llm = llm + # Add tool specs + self.tools: List[FunctionTool] = [] + # Mandatory tools + self.tools.extend(StringFunctionsToolSpec().to_tool_list()) + self.tools.extend(BasicArithmeticCalculatorSpec().to_tool_list()) + + # TODO: Populate the tools based on toolset names specified in the environment variables? + self.tools.extend(DuckDuckGoFullSearchOnlyToolSpec().to_tool_list()) + + def _are_tools_present(self, tool_names: list[str]) -> bool: + """ + Check if the tools with the given names are present in the current set of tools. + + Args: + tool_names (list[str]): The names of the tools to check. + + Returns: + bool: True if all the tools are present, False otherwise. + """ + return all( + tool_name in [tool.metadata.name for tool in self.tools] + for tool_name in tool_names + ) + + def _remove_tools_by_names(self, tool_names: list[str]): + """ + Remove the tools with the given names from the current set of tools. + + Args: + tool_names (list[str]): The names of the tools to remove. + """ + self.tools = [ + tool for tool in self.tools if tool.metadata.name not in tool_names + ] + + def is_toolset_present(self, toolset_name: str) -> bool: + """ + Check if the tools for the given toolset are present in the current set of tools. + + Args: + toolset_name (str): The name of the toolset to check. + + Returns: + bool: True if the tools are present, False otherwise. 
+ """ + if toolset_name == ToolNames.TOOL_NAME_ARXIV: + return self._are_tools_present( + [tool.metadata.name for tool in ArxivToolSpec().to_tool_list()] + ) + elif toolset_name == ToolNames.TOOL_NAME_BASIC_ARITHMETIC_CALCULATOR: + return self._are_tools_present( + [ + tool.metadata.name + for tool in BasicArithmeticCalculatorSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_MATHEMATICAL_FUNCTIONS: + return self._are_tools_present( + [ + tool.metadata.name + for tool in MathematicalFunctionsSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_DUCKDUCKGO: + return self._are_tools_present( + [ + tool.metadata.name + for tool in DuckDuckGoFullSearchOnlyToolSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_STRING_FUNCTIONS: + return self._are_tools_present( + [ + tool.metadata.name + for tool in StringFunctionsToolSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_TAVILY: + return self._are_tools_present( + [ + tool.metadata.name + for tool in TavilyToolSpec(api_key=FAKE_STRING).to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_WIKIPEDIA: + return self._are_tools_present( + [tool.metadata.name for tool in WikipediaToolSpec().to_tool_list()] + ) + elif toolset_name == ToolNames.TOOL_NAME_YAHOO_FINANCE: + return self._are_tools_present( + [tool.metadata.name for tool in YahooFinanceToolSpec().to_tool_list()] + ) + + def get_selected_web_search_toolset(self) -> str: + """ + Get the name of the web search toolset currently selected. + + Returns: + str: The name of the web search toolset. + """ + if self.is_toolset_present(ToolNames.TOOL_NAME_DUCKDUCKGO): + return ToolNames.TOOL_NAME_DUCKDUCKGO + elif self.is_toolset_present(ToolNames.TOOL_NAME_TAVILY): + return ToolNames.TOOL_NAME_TAVILY + else: + # Unknown or no toolset selected. 
+ return ToolNames.TOOL_NAME_SELECTION_DISABLE + + def remove_toolset(self, toolset_name: str): + """ + Remove the tools for the given toolset from the current set of tools. + + Args: + toolset_name (str): The name of the toolset to remove. + """ + if toolset_name == ToolNames.TOOL_NAME_ARXIV: + self._remove_tools_by_names( + [tool.metadata.name for tool in ArxivToolSpec().to_tool_list()] + ) + elif toolset_name == ToolNames.TOOL_NAME_BASIC_ARITHMETIC_CALCULATOR: + self._remove_tools_by_names( + [ + tool.metadata.name + for tool in BasicArithmeticCalculatorSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_MATHEMATICAL_FUNCTIONS: + self._remove_tools_by_names( + [ + tool.metadata.name + for tool in MathematicalFunctionsSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_DUCKDUCKGO: + self._remove_tools_by_names( + [ + tool.metadata.name + for tool in DuckDuckGoFullSearchOnlyToolSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_STRING_FUNCTIONS: + self._remove_tools_by_names( + [ + tool.metadata.name + for tool in StringFunctionsToolSpec().to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_TAVILY: + self._remove_tools_by_names( + [ + tool.metadata.name + for tool in TavilyToolSpec(api_key=FAKE_STRING).to_tool_list() + ] + ) + elif toolset_name == ToolNames.TOOL_NAME_WIKIPEDIA: + self._remove_tools_by_names( + [tool.metadata.name for tool in WikipediaToolSpec().to_tool_list()] + ) + elif toolset_name == ToolNames.TOOL_NAME_YAHOO_FINANCE: + self._remove_tools_by_names( + [tool.metadata.name for tool in YahooFinanceToolSpec().to_tool_list()] + ) + + def add_or_set_toolset( + self, + toolset_name: str, + api_key: str | None = None, + remove_existing: bool = True, + ): + """ + Add or set the tools for the given toolset. + + Args: + toolset_name (str): The name of the toolset to add or set. + api_key (str): The API key to use with the toolset. 
+ remove_existing (bool): Whether to remove the existing tools for the toolset before adding the new ones. + """ + + # Remove the existing tools for the toolset to avoid duplicates + if remove_existing: + self.remove_toolset(toolset_name) + + if toolset_name == ToolNames.TOOL_NAME_ARXIV: + self.tools.extend(ArxivToolSpec().to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_BASIC_ARITHMETIC_CALCULATOR: + self.tools.extend(BasicArithmeticCalculatorSpec().to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_MATHEMATICAL_FUNCTIONS: + self.tools.extend(MathematicalFunctionsSpec().to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_DUCKDUCKGO: + self.tools.extend(DuckDuckGoFullSearchOnlyToolSpec().to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_STRING_FUNCTIONS: + self.tools.extend(StringFunctionsToolSpec().to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_TAVILY: + self.tools.extend(TavilyToolSpec(api_key=api_key).to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_WIKIPEDIA: + self.tools.extend(WikipediaToolSpec().to_tool_list()) + elif toolset_name == ToolNames.TOOL_NAME_YAHOO_FINANCE: + self.tools.extend(YahooFinanceToolSpec().to_tool_list()) + + def set_web_search_tool( + self, search_tool: str, search_tool_api_key: str | None = None + ): + """ + Set the web search tool to use for the Difficult Questions Attempted engine. + + Args: + search_tool (str): The web search tool to use. + search_tool_api_key (str | None): The API key to use with the web search tool. + """ + + self.remove_toolset(ToolNames.TOOL_NAME_DUCKDUCKGO) + self.remove_toolset(ToolNames.TOOL_NAME_TAVILY) + + if search_tool != ToolNames.TOOL_NAME_SELECTION_DISABLE: + self.add_or_set_toolset( + search_tool, api_key=search_tool_api_key, remove_existing=False + ) + + def get_descriptive_tools_dataframe(self): + """ + Get a dataframe consisting of the names and descriptions of the tools currently available. 
+ """ + return [ + [ + f"`{tool.metadata.name}`", + tool.metadata.description.split("\n\n")[1].strip(), + ] + for tool in self.tools + ] + + async def run(self, query: str): + """ + Run the Difficult Questions Attempted engine with the given query. + + Args: + query (str): The query to process. + + Yields: + tuple: A tuple containing the done status, the number of finished steps, the total number of steps, and the message + for each step of the workflow. The message is the response to the query when the workflow is done. + """ + # Instantiating the ReAct workflow instead may not be always enough to get the desired responses to certain questions. + self.workflow = StructuredSubQuestionReActWorkflow( + llm=self.llm, + tools=self.tools, + timeout=180, + verbose=False, + ) + # No need for this, see: https://github.com/run-llama/llama_index/discussions/15838#discussioncomment-10553154 + # self.workflow.add_workflows( + # react_workflow=ReActWorkflow( + # llm=self.llm, tools=self.tools, timeout=60, verbose=True + # ) + # ) + # No longer usable in this way, due to breaking changes in LlamaIndex Workflows. + # task = asyncio.create_task( + # self.workflow.run( + # query=query, + # ) + # ) + task: asyncio.Future = self.workflow.run( + query=query, + ) + done: bool = False + total_steps: int = 0 + finished_steps: int = 0 + terminal_columns, _ = get_terminal_size() + progress_bar = tqdm( + total=total_steps, + leave=False, + unit="step", + ncols=int(terminal_columns / 2), + desc=APP_TITLE_SHORT, + colour="yellow", + ) + async for ev in self.workflow.stream_events(): + total_steps = ev.total_steps + finished_steps = ev.finished_steps + print(f"\n{str(ev.msg)}", flush=True) + # TODO: Is tqdm.write better than printf? 
+ # tqdm.write(f"\n{str(ev.msg)}") + progress_bar.reset(total=total_steps) + progress_bar.update(finished_steps) + progress_bar.refresh() + yield done, finished_steps, total_steps, ev.msg + try: + done, _ = await asyncio.wait([task]) + if done: + result = task.result() + except Exception as e: + result = f"\nException in running the workflow(s). Type: {type(e).__name__}. Message: '{str(e)}'" + # Set this to done, otherwise another workflow call cannot be made. + done = True + print(result, file=sys.stderr) + finally: + progress_bar.close() + yield done, finished_steps, total_steps, result diff --git a/src/webapp.py b/src/webapp.py index 16fda70..6a7cfea 100644 --- a/src/webapp.py +++ b/src/webapp.py @@ -23,7 +23,7 @@ import gradio as gr -from workflows.dqa import DQAEngine +from dqa import DQAEngine from utils import ( APP_TITLE_FULL, COLON_STRING, diff --git a/src/workflows/dqa.py b/src/workflows/ssq_react.py similarity index 54% rename from src/workflows/dqa.py rename to src/workflows/ssq_react.py index e4bb97e..284cfb8 100644 --- a/src/workflows/dqa.py +++ b/src/workflows/ssq_react.py @@ -12,28 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Difficult Questions Attempted module containing various workflows.""" +"""Structured Sub-Question ReAct (SSQReAct) workflow.""" try: from icecream import ic except ImportError: # Graceful fallback if IceCream isn't installed. ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa -import sys -from tqdm import tqdm import asyncio # Weaker LLMs may generate horrible JSON strings. # `dirtyjson` is more lenient than `json` in parsing JSON strings. 
import dirtyjson as json from typing import Any, List -from llama_index.tools.arxiv import ArxivToolSpec - -from llama_index.tools.wikipedia import WikipediaToolSpec -from llama_index.tools.tavily_research import TavilyToolSpec - -from llama_index.tools.yahoo_finance import YahooFinanceToolSpec -from llama_index.core.tools import FunctionTool from llama_index.core.workflow import ( step, @@ -44,19 +35,6 @@ StopEvent, ) -from tools import ( - DuckDuckGoFullSearchOnlyToolSpec, - StringFunctionsToolSpec, - BasicArithmeticCalculatorSpec, - MathematicalFunctionsSpec, -) -from utils import ( - APP_TITLE_SHORT, - FAKE_STRING, - ToolNames, - get_terminal_size, -) # , parse_env, EnvironmentVariables - from llama_index.core.llms.llm import LLM from llama_index.core.tools.types import BaseTool @@ -66,9 +44,9 @@ from workflows.self_discover import SelfDiscoverWorkflow -class DQAReasoningStructureEvent(Event): +class SSQReActReasoningStructureEvent(Event): """ - Event to handle reasoning structure for DQA. + Event to handle reasoning structure for SSQReAct. Fields: question (str): The question. @@ -78,9 +56,9 @@ class DQAReasoningStructureEvent(Event): reasoning_structure: str -class DQASequentialQueryEvent(Event): +class SSQReActSequentialQueryEvent(Event): """ - Event to handle a DQA query from a list of questions. This event is used to handle questions sequentially. + Event to handle a SSQReAct query from a list of questions. This event is used to handle questions sequentially. The question index is used to keep track of the current question being handled. The list of questions are expected to be stored in the context. @@ -97,9 +75,9 @@ class DQASequentialQueryEvent(Event): question_index: int = 0 -class DQAAnswerEvent(Event): +class SSQReActAnswerEvent(Event): """ - Event to handle a DQA answer. + Event to handle a SSQReAct answer. Fields: question (str): The question. 
@@ -112,7 +90,7 @@ class DQAAnswerEvent(Event): sources: List[Any] = [] -class DQAReviewSubQuestionEvent(Event): +class SSQReActReviewSubQuestionEvent(Event): """ Event to review the sub-questions. @@ -125,8 +103,8 @@ class DQAReviewSubQuestionEvent(Event): satisfied: bool = False -class DQAWorkflow(Workflow): - """A workflow implementation for DQA: Difficult Questions Attempted.""" +class StructuredSubQuestionReActWorkflow(Workflow): + """A workflow implementation for SSQReAct: Structured Sub-Question ReAct.""" KEY_ORIGINAL_QUERY = "original_query" KEY_REASONING_STRUCTURE = "reasoning_structure" @@ -144,7 +122,7 @@ def __init__( **kwargs: Any, ) -> None: """ - Initialize the DQA workflow. + Initialize the SSQReAct workflow. Args: llm (LLM): The LLM instance to use. @@ -164,7 +142,7 @@ def __init__( @step async def start( self, ctx: Context, ev: StartEvent - ) -> DQAReasoningStructureEvent | StopEvent: + ) -> SSQReActReasoningStructureEvent | StopEvent: """ As a start event of the workflow, this step receives the original query and stores it in the context. @@ -173,12 +151,14 @@ async def start( ev (StartEvent): The start event. Returns: - DQAReasoningStructureEvent | StopEvent: The event containing the reasoning structure or the event to stop the workflow. + SSQReActReasoningStructureEvent | StopEvent: The event containing the reasoning structure or the event to stop the workflow. """ if hasattr(ev, "query"): - await ctx.set(DQAWorkflow.KEY_ORIGINAL_QUERY, ev.query) await ctx.set( - DQAWorkflow.KEY_REACT_CONTEXT, + StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY, ev.query + ) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_REACT_CONTEXT, ( "\nPAY ATTENTION since the given question may make implicit references to information in the context. " "If you find or can deduce the answer to the given question from the context below, PLEASE refrain from calling any further tools. 
" @@ -193,7 +173,7 @@ async def start( self._total_steps += 1 ctx.write_event_to_stream( WorkflowStatusEvent( - msg=f"Generating structured reasoning for the query:\n\t{await ctx.get(DQAWorkflow.KEY_ORIGINAL_QUERY)}", + msg=f"Generating structured reasoning for the query:\n\t{await ctx.get(StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY)}", total_steps=self._total_steps, finished_steps=self._finished_steps, ) @@ -201,7 +181,7 @@ async def start( self_discover_workflow = SelfDiscoverWorkflow( llm=self.llm, - # Let's set the timeout of the ReAct workflow to half of the DQA workflow's timeout. + # Let's set the timeout of the ReAct workflow to half of the SSQReAct workflow's timeout. timeout=self._timeout / 2, verbose=self._verbose, plan_only=True, @@ -224,12 +204,12 @@ async def start( response = self_discover_task.result() self._finished_steps += 1 - return DQAReasoningStructureEvent(reasoning_structure=response) + return SSQReActReasoningStructureEvent(reasoning_structure=response) @step async def query( - self, ctx: Context, ev: DQAReasoningStructureEvent - ) -> DQASequentialQueryEvent | DQAReviewSubQuestionEvent | StopEvent: + self, ctx: Context, ev: SSQReActReasoningStructureEvent + ) -> SSQReActSequentialQueryEvent | SSQReActReviewSubQuestionEvent | StopEvent: """ This step receives the structured reasoning for the query. It then asks the LLM to decompose the query into sub-questions. Upon decomposition, it emits every @@ -238,19 +218,22 @@ async def query( Args: ctx (Context): The context object. - ev (DQAStructuredReasoningEvent): The event containing the structured reasoning. + ev (SSQReActStructuredReasoningEvent): The event containing the structured reasoning. 
Returns: - DQASequentialQueryEvent | DQAReviewSubQuestionEvent | StopEvent: The event containing the sub-question index to process or the event + SSQReActSequentialQueryEvent | SSQReActReviewSubQuestionEvent | StopEvent: The event containing the sub-question index to process or the event to review the sub-questions or the event to stop the workflow. """ - await ctx.set(DQAWorkflow.KEY_REASONING_STRUCTURE, ev.reasoning_structure) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_REASONING_STRUCTURE, + ev.reasoning_structure, + ) self._total_steps += 1 ctx.write_event_to_stream( WorkflowStatusEvent( - msg=f"Assessing query and plan:\n\t{await ctx.get(DQAWorkflow.KEY_ORIGINAL_QUERY)}", + msg=f"Assessing query and plan:\n\t{await ctx.get(StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY)}", total_steps=self._total_steps, finished_steps=self._finished_steps, ) @@ -271,51 +254,53 @@ async def query( "Question: Is Hamlet more common on IMDB than Comedy of Errors?\n" "Decompositions:\n" "{\n" - f' "{DQAWorkflow.KEY_SUB_QUESTIONS}": [\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS}": [\n' ' "How many listings of Hamlet are there on IMDB?",\n' ' "How many listings of Comedy of Errors is there on IMDB?"\n' " ],\n" - f' "{DQAWorkflow.KEY_SATISFIED}": true\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SATISFIED}": true\n' "}" "\n\nExample 2:\n" "Question: What is the capital city of Japan?\n" "Decompositions:\n" "{\n" - f' "{DQAWorkflow.KEY_SUB_QUESTIONS}": ["What is the capital city of Japan?"],\n' - f' "{DQAWorkflow.KEY_SATISFIED}": true\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS}": ["What is the capital city of Japan?"],\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SATISFIED}": true\n' "}\n" "Note that this question above needs no decomposition. Hence, the original question is output as the only sub-question." 
"\n\nExample 3:\n" "Question: Are there more hydrogen atoms in methyl alcohol than in ethyl alcohol?\n" "Decompositions:\n" "{\n" - f' "{DQAWorkflow.KEY_SUB_QUESTIONS}": [\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS}": [\n' ' "How many hydrogen atoms are there in methyl alcohol?",\n' ' "How many hydrogen atoms are there in ethyl alcohol?",\n' ' "What is the chemical composition of alcohol?"\n' " ],\n" - f' "{DQAWorkflow.KEY_SATISFIED}": false\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SATISFIED}": false\n' "}\n" "Note that the third sub-question is unnecessary and should not be included. Hence, the value of the satisfied flag is set to false." "\n\nAlways, respond in pure JSON without any Markdown, like this:\n" "{\n" - f' "{DQAWorkflow.KEY_SUB_QUESTIONS}": [\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS}": [\n' ' "sub question 1",\n' ' "sub question 2",\n' ' "sub question 3"\n' " ],\n" - f' "{DQAWorkflow.KEY_SATISFIED}": true or false\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SATISFIED}": true or false\n' "}" "\nDO NOT hallucinate!" 
- f"\n\nHere is the user question: {await ctx.get(DQAWorkflow.KEY_ORIGINAL_QUERY)}" - f"\n\nAnd, here is the corresponding reasoning structure:\n{await ctx.get(DQAWorkflow.KEY_REASONING_STRUCTURE)}" + f"\n\nHere is the user question: {await ctx.get(StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY)}" + f"\n\nAnd, here is the corresponding reasoning structure:\n{await ctx.get(StructuredSubQuestionReActWorkflow.KEY_REASONING_STRUCTURE)}" ) response = await self.llm.acomplete(prompt) self._finished_steps += 1 response_obj = json.loads(str(response)) - sub_questions = response_obj[DQAWorkflow.KEY_SUB_QUESTIONS] - satisfied = response_obj[DQAWorkflow.KEY_SATISFIED] + sub_questions = response_obj[ + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS + ] + satisfied = response_obj[StructuredSubQuestionReActWorkflow.KEY_SATISFIED] ctx.write_event_to_stream( WorkflowStatusEvent( @@ -325,24 +310,29 @@ async def query( ) ) - await ctx.set(DQAWorkflow.KEY_SUB_QUESTIONS, sub_questions) - await ctx.set(DQAWorkflow.KEY_SUB_QUESTIONS_COUNT, len(sub_questions)) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS, sub_questions + ) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS_COUNT, + len(sub_questions), + ) # Ignore the satisfied flag if there is only one sub-question. 
if len(sub_questions) == 1: - return DQASequentialQueryEvent() + return SSQReActSequentialQueryEvent() else: if satisfied: - return DQASequentialQueryEvent() + return SSQReActSequentialQueryEvent() else: - return DQAReviewSubQuestionEvent( + return SSQReActReviewSubQuestionEvent( questions=sub_questions, satisfied=satisfied ) @step async def review_sub_questions( - self, ctx: Context, ev: DQAReviewSubQuestionEvent - ) -> DQASequentialQueryEvent | DQAReviewSubQuestionEvent: + self, ctx: Context, ev: SSQReActReviewSubQuestionEvent + ) -> SSQReActSequentialQueryEvent | SSQReActReviewSubQuestionEvent: """ This step receives the sub-questions and asks the LLM to review them. If the LLM is satisfied with the sub-questions, they can be used to answer the original question. Otherwise, the LLM can provide updated @@ -350,16 +340,16 @@ async def review_sub_questions( Args: ctx (Context): The context object. - ev (DQAReviewSubQuestionEvent): The event containing the sub-questions. + ev (SSQReActReviewSubQuestionEvent): The event containing the sub-questions. Returns: - DQASequentialQueryEvent | DQAReviewSubQuestionEvent: The event containing the sub-question index to process or the event to review the + SSQReActSequentialQueryEvent | SSQReActReviewSubQuestionEvent: The event containing the sub-question index to process or the event to review the sub-questions. """ if ev.satisfied: # Already satisfied, no need to review anymore. - return DQASequentialQueryEvent() + return SSQReActSequentialQueryEvent() self._total_steps += 1 ctx.write_event_to_stream( @@ -387,25 +377,27 @@ async def review_sub_questions( "\n\nLastly, reflect on the amended sub-questions and generate a binary response indicating whether you are satisfied with the amended sub-questions or not." 
"\n\nAlways, respond in pure JSON without any Markdown, like this:" "{\n" - f' "{DQAWorkflow.KEY_SUB_QUESTIONS}": [\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS}": [\n' ' "sub question 1",\n' ' "sub question 2",\n' ' "sub question 3"\n' " ],\n" - f' "{DQAWorkflow.KEY_SATISFIED}": true or false\n' + f' "{StructuredSubQuestionReActWorkflow.KEY_SATISFIED}": true or false\n' "}" "\nDO NOT hallucinate!" - f"\n\nHere is the user question: {await ctx.get(DQAWorkflow.KEY_ORIGINAL_QUERY)}" + f"\n\nHere is the user question: {await ctx.get(StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY)}" f"\n\nHere are the sub-questions for you to review:\n{ev.questions}" - f"\n\nAnd, here is the corresponding reasoning structure:\n{await ctx.get(DQAWorkflow.KEY_REASONING_STRUCTURE)}" + f"\n\nAnd, here is the corresponding reasoning structure:\n{await ctx.get(StructuredSubQuestionReActWorkflow.KEY_REASONING_STRUCTURE)}" ) response = await self.llm.acomplete(prompt) self._finished_steps += 1 self._refinement_iterations += 1 response_obj = json.loads(str(response)) - sub_questions = response_obj[DQAWorkflow.KEY_SUB_QUESTIONS] - satisfied = response_obj[DQAWorkflow.KEY_SATISFIED] + sub_questions = response_obj[ + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS + ] + satisfied = response_obj[StructuredSubQuestionReActWorkflow.KEY_SATISFIED] ctx.write_event_to_stream( WorkflowStatusEvent( @@ -415,44 +407,53 @@ async def review_sub_questions( ) ) - await ctx.set(DQAWorkflow.KEY_SUB_QUESTIONS, sub_questions) - await ctx.set(DQAWorkflow.KEY_SUB_QUESTIONS_COUNT, len(sub_questions)) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS, sub_questions + ) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS_COUNT, + len(sub_questions), + ) # Ignore the satisfied flag if there is only one sub-question. 
if len(sub_questions) == 1: - return DQASequentialQueryEvent() + return SSQReActSequentialQueryEvent() if satisfied or self._refinement_iterations >= self._max_refinement_iterations: - return DQASequentialQueryEvent() + return SSQReActSequentialQueryEvent() else: - return DQAReviewSubQuestionEvent( + return SSQReActReviewSubQuestionEvent( questions=sub_questions, satisfied=satisfied ) @step async def answer_sub_question( - self, ctx: Context, ev: DQASequentialQueryEvent - ) -> DQAAnswerEvent | DQASequentialQueryEvent: + self, ctx: Context, ev: SSQReActSequentialQueryEvent + ) -> SSQReActAnswerEvent | SSQReActSequentialQueryEvent: """ This step receives a sub-question and attempts to answer it using the tools provided in the context. Args: ctx (Context): The context object. - ev (DQASequentialQueryEvent): The event containing the sub-question index to process. + ev (SSQReActSequentialQueryEvent): The event containing the sub-question index to process. Returns: - DQAAnswerEvent: The event containing the sub-question and the answer. + SSQReActAnswerEvent: The event containing the sub-question and the answer. """ - if isinstance(ev, DQASequentialQueryEvent): - sub_questions = await ctx.get(DQAWorkflow.KEY_SUB_QUESTIONS) + if isinstance(ev, SSQReActSequentialQueryEvent): + sub_questions = await ctx.get( + StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS + ) if not sub_questions or len(sub_questions) == 0: raise ValueError("No questions to answer.") question = sub_questions[ev.question_index] else: question = ev.question - react_context = await ctx.get(DQAWorkflow.KEY_REACT_CONTEXT) + react_context = await ctx.get( + StructuredSubQuestionReActWorkflow.KEY_REACT_CONTEXT + ) self._total_steps += 1 ctx.write_event_to_stream( @@ -468,7 +469,7 @@ async def answer_sub_question( react_workflow = ReActWorkflow( llm=self.llm, tools=self.tools, - # Let's set the timeout of the ReAct workflow to half of the DQA workflow's timeout. 
+ # Let's set the timeout of the ReAct workflow to half of the SSQReAct workflow's timeout. timeout=self._timeout / 2, verbose=self._verbose, # Let's keep the maximum iterations of the ReAct workflow to its default value. @@ -493,7 +494,7 @@ async def answer_sub_question( response = react_task.result() self._finished_steps += 1 - react_answer_event = DQAAnswerEvent( + react_answer_event = SSQReActAnswerEvent( question=question, answer=response[ReActWorkflow.KEY_RESPONSE], # TODO: Should we format the sources nicely here so that the LLM does not have to deal with it later? @@ -510,20 +511,24 @@ async def answer_sub_question( # f"> Sources: {', '.join(react_answer_event.sources)}" ) - await ctx.set(DQAWorkflow.KEY_REACT_CONTEXT, react_context) + await ctx.set( + StructuredSubQuestionReActWorkflow.KEY_REACT_CONTEXT, react_context + ) ctx.send_event(react_answer_event) - if isinstance(ev, DQASequentialQueryEvent): + if isinstance(ev, SSQReActSequentialQueryEvent): if ev.question_index + 1 < len(sub_questions): # Let's move to the next sub-question. - return DQASequentialQueryEvent(question_index=ev.question_index + 1) + return SSQReActSequentialQueryEvent( + question_index=ev.question_index + 1 + ) return None @step async def combine_refine_answers( - self, ctx: Context, ev: DQAAnswerEvent + self, ctx: Context, ev: SSQReActAnswerEvent ) -> StopEvent | None: """ This step receives the answers to the sub-questions and combines them into a single answer to the original @@ -532,14 +537,16 @@ async def combine_refine_answers( Args: ctx (Context): The context object. - ev (DQAAnswerEvent): The event containing the sub-question and the answer. + ev (SSQReActAnswerEvent): The event containing the sub-question and the answer. Returns: StopEvent | None: The event containing the final answer to the original question, or None if the sub-questions have not all been answered. 
""" ready = ctx.collect_events( - ev, [DQAAnswerEvent] * await ctx.get(DQAWorkflow.KEY_SUB_QUESTIONS_COUNT) + ev, + [SSQReActAnswerEvent] + * await ctx.get(StructuredSubQuestionReActWorkflow.KEY_SUB_QUESTIONS_COUNT), ) if ready is None: return None @@ -558,7 +565,7 @@ async def combine_refine_answers( self._total_steps += 1 ctx.write_event_to_stream( WorkflowStatusEvent( - msg=f"Generating the final response to the original query:\n\t{await ctx.get(DQAWorkflow.KEY_ORIGINAL_QUERY)}", + msg=f"Generating the final response to the original query:\n\t{await ctx.get(StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY)}", total_steps=self._total_steps, finished_steps=self._finished_steps, ) @@ -578,9 +585,9 @@ async def combine_refine_answers( "In your final answer, cite each source and its corresponding URLs, only if such source URLs are available are in the answers to the sub-questions." "\nDo not make up sources or URLs if they are not present in the answers to the sub-questions. " "\nYour final answer must be correctly formatted as pure HTML (with no Javascript and Markdown) in a concise, readable and visually pleasing way. " - "Enclose your HTML response with a
tag that has an attribute `id` set to the value 'dqa_workflow_response'." + "Enclose your HTML response with a
tag that has an attribute `id` set to the value 'workflow_response'." "\nDO NOT hallucinate!" - f"\n\nOriginal question: {await ctx.get(DQAWorkflow.KEY_ORIGINAL_QUERY)}" + f"\n\nOriginal question: {await ctx.get(StructuredSubQuestionReActWorkflow.KEY_ORIGINAL_QUERY)}" f"\n\nSub-questions, answers and relevant sources:\n{answers}" ) @@ -595,313 +602,3 @@ async def combine_refine_answers( ) ) return StopEvent(result=str(response)) - - -class DQAEngine: - """The Difficult Questions Attempted engine.""" - - def __init__(self, llm: LLM | None = None): - """ - Initialize the Difficult Questions Attempted engine. - - Args: - llm (LLM): The function calling LLM instance to use. - """ - self.llm = llm - # Add tool specs - self.tools: List[FunctionTool] = [] - # Mandatory tools - self.tools.extend(StringFunctionsToolSpec().to_tool_list()) - self.tools.extend(BasicArithmeticCalculatorSpec().to_tool_list()) - - # TODO: Populate the tools based on toolset names specified in the environment variables? - self.tools.extend(DuckDuckGoFullSearchOnlyToolSpec().to_tool_list()) - - def _are_tools_present(self, tool_names: list[str]) -> bool: - """ - Check if the tools with the given names are present in the current set of tools. - - Args: - tool_names (list[str]): The names of the tools to check. - - Returns: - bool: True if all the tools are present, False otherwise. - """ - return all( - tool_name in [tool.metadata.name for tool in self.tools] - for tool_name in tool_names - ) - - def _remove_tools_by_names(self, tool_names: list[str]): - """ - Remove the tools with the given names from the current set of tools. - - Args: - tool_names (list[str]): The names of the tools to remove. - """ - self.tools = [ - tool for tool in self.tools if tool.metadata.name not in tool_names - ] - - def is_toolset_present(self, toolset_name: str) -> bool: - """ - Check if the tools for the given toolset are present in the current set of tools. 
- - Args: - toolset_name (str): The name of the toolset to check. - - Returns: - bool: True if the tools are present, False otherwise. - """ - if toolset_name == ToolNames.TOOL_NAME_ARXIV: - return self._are_tools_present( - [tool.metadata.name for tool in ArxivToolSpec().to_tool_list()] - ) - elif toolset_name == ToolNames.TOOL_NAME_BASIC_ARITHMETIC_CALCULATOR: - return self._are_tools_present( - [ - tool.metadata.name - for tool in BasicArithmeticCalculatorSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_MATHEMATICAL_FUNCTIONS: - return self._are_tools_present( - [ - tool.metadata.name - for tool in MathematicalFunctionsSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_DUCKDUCKGO: - return self._are_tools_present( - [ - tool.metadata.name - for tool in DuckDuckGoFullSearchOnlyToolSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_STRING_FUNCTIONS: - return self._are_tools_present( - [ - tool.metadata.name - for tool in StringFunctionsToolSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_TAVILY: - return self._are_tools_present( - [ - tool.metadata.name - for tool in TavilyToolSpec(api_key=FAKE_STRING).to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_WIKIPEDIA: - return self._are_tools_present( - [tool.metadata.name for tool in WikipediaToolSpec().to_tool_list()] - ) - elif toolset_name == ToolNames.TOOL_NAME_YAHOO_FINANCE: - return self._are_tools_present( - [tool.metadata.name for tool in YahooFinanceToolSpec().to_tool_list()] - ) - - def get_selected_web_search_toolset(self) -> str: - """ - Get the name of the web search toolset currently selected. - - Returns: - str: The name of the web search toolset. - """ - if self.is_toolset_present(ToolNames.TOOL_NAME_DUCKDUCKGO): - return ToolNames.TOOL_NAME_DUCKDUCKGO - elif self.is_toolset_present(ToolNames.TOOL_NAME_TAVILY): - return ToolNames.TOOL_NAME_TAVILY - else: - # Unknown or no toolset selected. 
- return ToolNames.TOOL_NAME_SELECTION_DISABLE - - def remove_toolset(self, toolset_name: str): - """ - Remove the tools for the given toolset from the current set of tools. - - Args: - toolset_name (str): The name of the toolset to remove. - """ - if toolset_name == ToolNames.TOOL_NAME_ARXIV: - self._remove_tools_by_names( - [tool.metadata.name for tool in ArxivToolSpec().to_tool_list()] - ) - elif toolset_name == ToolNames.TOOL_NAME_BASIC_ARITHMETIC_CALCULATOR: - self._remove_tools_by_names( - [ - tool.metadata.name - for tool in BasicArithmeticCalculatorSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_MATHEMATICAL_FUNCTIONS: - self._remove_tools_by_names( - [ - tool.metadata.name - for tool in MathematicalFunctionsSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_DUCKDUCKGO: - self._remove_tools_by_names( - [ - tool.metadata.name - for tool in DuckDuckGoFullSearchOnlyToolSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_STRING_FUNCTIONS: - self._remove_tools_by_names( - [ - tool.metadata.name - for tool in StringFunctionsToolSpec().to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_TAVILY: - self._remove_tools_by_names( - [ - tool.metadata.name - for tool in TavilyToolSpec(api_key=FAKE_STRING).to_tool_list() - ] - ) - elif toolset_name == ToolNames.TOOL_NAME_WIKIPEDIA: - self._remove_tools_by_names( - [tool.metadata.name for tool in WikipediaToolSpec().to_tool_list()] - ) - elif toolset_name == ToolNames.TOOL_NAME_YAHOO_FINANCE: - self._remove_tools_by_names( - [tool.metadata.name for tool in YahooFinanceToolSpec().to_tool_list()] - ) - - def add_or_set_toolset( - self, - toolset_name: str, - api_key: str | None = None, - remove_existing: bool = True, - ): - """ - Add or set the tools for the given toolset. - - Args: - toolset_name (str): The name of the toolset to add or set. - api_key (str): The API key to use with the toolset. 
- remove_existing (bool): Whether to remove the existing tools for the toolset before adding the new ones. - """ - - # Remove the existing tools for the toolset to avoid duplicates - if remove_existing: - self.remove_toolset(toolset_name) - - if toolset_name == ToolNames.TOOL_NAME_ARXIV: - self.tools.extend(ArxivToolSpec().to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_BASIC_ARITHMETIC_CALCULATOR: - self.tools.extend(BasicArithmeticCalculatorSpec().to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_MATHEMATICAL_FUNCTIONS: - self.tools.extend(MathematicalFunctionsSpec().to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_DUCKDUCKGO: - self.tools.extend(DuckDuckGoFullSearchOnlyToolSpec().to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_STRING_FUNCTIONS: - self.tools.extend(StringFunctionsToolSpec().to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_TAVILY: - self.tools.extend(TavilyToolSpec(api_key=api_key).to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_WIKIPEDIA: - self.tools.extend(WikipediaToolSpec().to_tool_list()) - elif toolset_name == ToolNames.TOOL_NAME_YAHOO_FINANCE: - self.tools.extend(YahooFinanceToolSpec().to_tool_list()) - - def set_web_search_tool( - self, search_tool: str, search_tool_api_key: str | None = None - ): - """ - Set the web search tool to use for the Difficult Questions Attempted engine. - - Args: - search_tool (str): The web search tool to use. - api_key (str): The API key to use with the web search tool. - """ - - self.remove_toolset(ToolNames.TOOL_NAME_DUCKDUCKGO) - self.remove_toolset(ToolNames.TOOL_NAME_TAVILY) - - if search_tool != ToolNames.TOOL_NAME_SELECTION_DISABLE: - self.add_or_set_toolset( - search_tool, api_key=search_tool_api_key, remove_existing=False - ) - - def get_descriptive_tools_dataframe(self): - """ - Get a dataframe consisting of the names and descriptions of the tools currently available. 
- """ - return [ - [ - f"`{tool.metadata.name}`", - tool.metadata.description.split("\n\n")[1].strip(), - ] - for tool in self.tools - ] - - async def run(self, query: str): - """ - Run the Difficult Questions Attempted engine with the given query. - - Args: - query (str): The query to process. - - Yields: - tuple: A tuple containing the done status, the number of finished steps, the total number of steps, and the message - for each step of the workflow. The message is the response to the query when the workflow is done. - """ - # Instantiating the ReAct workflow instead may not be always enough to get the desired responses to certain questions. - self.workflow = DQAWorkflow( - llm=self.llm, - tools=self.tools, - timeout=180, - verbose=False, - ) - # No need for this, see: https://github.com/run-llama/llama_index/discussions/15838#discussioncomment-10553154 - # self.workflow.add_workflows( - # react_workflow=ReActWorkflow( - # llm=self.llm, tools=self.tools, timeout=60, verbose=True - # ) - # ) - # No longer usable in this way, due to breaking changes in LlamaIndex Workflows. - # task = asyncio.create_task( - # self.workflow.run( - # query=query, - # ) - # ) - task: asyncio.Future = self.workflow.run( - query=query, - ) - done: bool = False - total_steps: int = 0 - finished_steps: int = 0 - terminal_columns, _ = get_terminal_size() - progress_bar = tqdm( - total=total_steps, - leave=False, - unit="step", - ncols=int(terminal_columns / 2), - desc=APP_TITLE_SHORT, - colour="yellow", - ) - async for ev in self.workflow.stream_events(): - total_steps = ev.total_steps - finished_steps = ev.finished_steps - print(f"\n{str(ev.msg)}", flush=True) - # TODO: Is tqdm.write better than printf? 
- # tqdm.write(f"\n{str(ev.msg)}") - progress_bar.reset(total=total_steps) - progress_bar.update(finished_steps) - progress_bar.refresh() - yield done, finished_steps, total_steps, ev.msg - try: - done, _ = await asyncio.wait([task]) - if done: - result = task.result() - except Exception as e: - result = f"\nException in running the workflow(s). Type: {type(e).__name__}. Message: '{str(e)}'" - # Set this to done, otherwise another workflow call cannot be made. - done = True - print(result, file=sys.stderr) - finally: - progress_bar.close() - yield done, finished_steps, total_steps, result