Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Created complete tool to allow unsure answers #684

Merged
merged 5 commits on Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.settings import Settings
from paperqa.types import PQASession, check_could_not_answer
from paperqa.types import PQASession
from paperqa.utils import get_year

from .models import QueryRequest
from .tools import (
AVAILABLE_TOOL_NAME_TO_CLASS,
Complete,
EnvironmentState,
GatherEvidence,
GenerateAnswer,
Expand All @@ -39,16 +40,16 @@ def settings_to_tools(
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
) -> list[Tool]:
"""
Convert a Settings into tools, confirming the gen_answer tool is present.
Convert a Settings into tools, confirming the complete tool is present.

NOTE: the last element of the return will always be GenerateAnswer.
NOTE: the last element of the return will always be Complete.
"""
llm_model = llm_model or settings.get_llm()
summary_llm_model = summary_llm_model or settings.get_summary_llm()
embedding_model = embedding_model or settings.get_embedding_model()
tools: list[Tool] = []
for tool_type in (
(PaperSearch, GatherEvidence, GenerateAnswer)
(PaperSearch, GatherEvidence, GenerateAnswer, Complete)
if settings.agent.tool_names is None
else [
AVAILABLE_TOOL_NAME_TO_CLASS[name]
Expand Down Expand Up @@ -82,9 +83,11 @@ def settings_to_tools(
embedding_model=embedding_model,
).gen_answer
)
elif issubclass(tool_type, Complete):
tool = Tool.from_function(Complete().complete)
else:
raise NotImplementedError(f"Didn't handle tool type {tool_type}.")
if tool.info.name == GenerateAnswer.gen_answer.__name__:
if tool.info.name == Complete.complete.__name__:
tools.append(tool) # Place at the end
else:
tools.insert(0, tool)
Expand Down Expand Up @@ -142,7 +145,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
content=self._query.settings.agent.agent_prompt.format(
question=self.state.session.question,
status=self.state.status,
gen_answer_tool_name=GenerateAnswer.TOOL_FN_NAME,
complete_tool_name=Complete.TOOL_FN_NAME,
),
)
],
Expand Down Expand Up @@ -170,30 +173,17 @@ async def step(
self, action: ToolRequestMessage
) -> tuple[Messages, float, bool, bool]:
self.state.record_action(action)
if not action.tool_calls:
return (
# NOTE: don't put:
# - GenerateAnswer.FAILED_TO_ANSWER here because this wasn't a failure
# - 'cannot answer' because that information belongs in
# PQASession.answer, not in the message history
# Let's just put a nice message about being done :)
[Message(content="Agent specified 0 tool calls, which means done.")],
self.USE_POST_PROCESSED_REWARD,
True, # Matching LangChain: https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.17/libs/langchain/langchain/agents/output_parsers/openai_functions.py#L38-L77
False, # Let caller determine truncations
)

response_messages = cast(
list[Message],
await self.exec_tool_calls(action, state=self.state, handle_tool_exc=True),
)
) or [Message(content=f"No tool calls input in tool request {action}.")]
return (
response_messages,
self.USE_POST_PROCESSED_REWARD,
any(
isinstance(msg, ToolResponseMessage)
and msg.name == GenerateAnswer.gen_answer.__name__
and not check_could_not_answer(msg.content)
and msg.name == Complete.complete.__name__
for msg in response_messages
)
or self._has_excess_answer_failures(),
Expand Down
35 changes: 6 additions & 29 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .tools import GenerateAnswer
from .tools import Complete

if TYPE_CHECKING:
from ldp.data_structures import Trajectory
Expand Down Expand Up @@ -130,31 +130,7 @@ async def step(
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
return messages, reward, done, truncated
valid_answers, failed_answer_messages = [], []
for m in messages:
if (
not isinstance(m, ToolResponseMessage)
or m.name != GenerateAnswer.gen_answer.__name__
):
continue # Filter out non-answer messages (in case parallel tool calls)
if answer := GenerateAnswer.extract_answer_from_message(content=m.content):
valid_answers.append(answer)
else:
failed_answer_messages.append(m)
if not valid_answers: # No answer, so no positive reward
return messages, reward, done, truncated
if len(valid_answers) != 1:
raise NotImplementedError(
f"Expected just one answer message, got more than one in {messages}."
)
answer = valid_answers[0]
if failed_answer_messages:
logger.warning(
"More than one answer detected, discarding failed answer messages"
f" {failed_answer_messages}, continuing with answer {answer}."
)
# Okay, so we have one answer that was not a failed answer. Let's evaluate it
evaluation = await self._evaluation_from_answer(answer)
evaluation = await self._evaluation_from_answer(self.state.session.answer)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
return messages, reward + self._rewards[evaluation.value], done, truncated
Expand Down Expand Up @@ -266,13 +242,14 @@ def compute_trajectory_metrics(
split_answers
for split_answers in (
re.split(
pattern=GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN,
pattern=Complete.ANSWER_SPLIT_REGEX_PATTERN,
string=obs.content,
maxsplit=1,
)
for obs in t.steps[-1].next_observation
if (
isinstance(obs, ToolResponseMessage)
and obs.name == GenerateAnswer.TOOL_FN_NAME
and obs.name == Complete.TOOL_FN_NAME
)
)
# Filter for places where the regex split succeeded
Expand All @@ -284,7 +261,7 @@ def compute_trajectory_metrics(
):
metric_list.append( # Use mean to allow for multiple answers
sum(int(sa[i]) for sa in split_answers) / len(split_answers)
if split_answers # Avoid div0 (when no answer was made)
if split_answers # Avoid div0 (when complete wasn't called)
else 0
)
return super().compute_trajectory_metrics(trajectories) | {
Expand Down
23 changes: 23 additions & 0 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ async def gen_answer(self, state: EnvironmentState) -> str:

return f"{answer} | {status}"

# Use to separate answer from status
# NOTE: can match failure to answer or an actual answer
ANSWER_SPLIT_REGEX_PATTERN: ClassVar[str] = (
r" \| " + EnvironmentState.STATUS_SEARCH_REGEX_PATTERN
Expand All @@ -340,6 +341,28 @@ def extract_answer_from_message(cls, content: str) -> str:
return answer


class Complete(NamedTool):
    # Name the agent uses to invoke this tool; also interpolated into agent
    # prompts (as {complete_tool_name}) and compared against tool-response
    # names when checking whether the trajectory is done.
    TOOL_FN_NAME: ClassVar[str] = "complete"

    # Use to separate answer from status in this tool's "answer | status"
    # return value (mirrors GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN so the
    # same postprocessing can split either tool's response).
    ANSWER_SPLIT_REGEX_PATTERN: ClassVar[str] = (
        r" \| " + EnvironmentState.STATUS_SEARCH_REGEX_PATTERN
    )

    # NOTE(review): this docstring presumably becomes the tool's description
    # in the LLM tool schema (the tool is built via Tool.from_function), so
    # its wording is agent-facing — confirm before editing it.
    async def complete(self, state: EnvironmentState) -> str:
        """
        Terminate using the last proposed answer.

        Do not invoke this tool in parallel with other tools or itself.

        Args:
            state: Current state.
        """
        logger.info(f"Completing '{state.session.question}'.")
        # Return answer and status to simplify postprocessing of tool response
        return f"{state.session.answer} | {state.status}"


AVAILABLE_TOOL_NAME_TO_CLASS: dict[str, type[NamedTool]] = {
cls.TOOL_FN_NAME: cls
for _, cls in inspect.getmembers(
Expand Down
2 changes: 1 addition & 1 deletion paperqa/configs/contracrow.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"agent_llm": "gpt-4o-2024-08-06",
"agent_type": "ToolSelector",
"agent_system_prompt": "You are a helpful AI assistant.",
"agent_prompt": "Answer question: {question}\n\nSearch for papers, gather evidence, collect papers cited in evidence then re-gather evidence, and answer. Gathering evidence will do nothing if you have not done a new search or collected new papers. If you do not have enough evidence to generate a good answer, you can:\n- Search for more papers (preferred)\n- Collect papers cited by previous evidence (preferred)\n- Gather more evidence using a different phrase\nIf you search for more papers or collect new papers cited by previous evidence, remember to gather evidence again. Once you have five or more pieces of evidence from multiple sources, or you have tried a few times, call {gen_answer_tool_name} tool. The {gen_answer_tool_name} tool output is visible to the user, so you do not need to restate the answer and can simply terminate if the answer looks sufficient. The current status of evidence/papers/cost is {status}",
"agent_prompt": "Answer question: {question}\n\nSearch for papers, gather evidence, collect papers cited in evidence then re-gather evidence, answer, and complete. Gathering evidence will do nothing if you have not done a new search or collected new papers. If you do not have enough evidence to generate a good answer, you can:\n- Search for more papers (preferred)\n- Collect papers cited by previous evidence (preferred)\n- Gather more evidence using a different phrase\nIf you search for more papers or collect new papers cited by previous evidence, remember to gather evidence again. Once you have five or more pieces of evidence from multiple sources, or you have tried a few times, call the {complete_tool_name} tool to terminate. The current status of evidence/papers/cost is {status}",
"search_count": 12,
"wipe_context_on_answer_failure": true,
"timeout": 500.0,
Expand Down
2 changes: 1 addition & 1 deletion paperqa/configs/wikicrow.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"agent_llm": "gpt-4-turbo-2024-04-09",
"agent_type": "ToolSelector",
"agent_system_prompt": "You are a helpful AI assistant.",
"agent_prompt": "Answer question: {question}\n\nSearch for papers, gather evidence, collect papers cited in evidence then re-gather evidence, and answer. Gathering evidence will do nothing if you have not done a new search or collected new papers. If you do not have enough evidence to generate a good answer, you can:\n- Search for more papers (preferred)\n- Collect papers cited by previous evidence (preferred)\n- Gather more evidence using a different phrase\nIf you search for more papers or collect new papers cited by previous evidence, remember to gather evidence again. Once you have five or more pieces of evidence from multiple sources, or you have tried a few times, call {gen_answer_tool_name} tool. The {gen_answer_tool_name} tool output is visible to the user, so you do not need to restate the answer and can simply terminate if the answer looks sufficient. The current status of evidence/papers/cost is {status}",
"agent_prompt": "Answer question: {question}\n\nSearch for papers, gather evidence, collect papers cited in evidence then re-gather evidence, answer, and complete. Gathering evidence will do nothing if you have not done a new search or collected new papers. If you do not have enough evidence to generate a good answer, you can:\n- Search for more papers (preferred)\n- Collect papers cited by previous evidence (preferred)\n- Gather more evidence using a different phrase\nIf you search for more papers or collect new papers cited by previous evidence, remember to gather evidence again. Once you have five or more pieces of evidence from multiple sources, or you have tried a few times, call the {complete_tool_name} tool to terminate. The current status of evidence/papers/cost is {status}",
"search_count": 12,
"wipe_context_on_answer_failure": true,
"timeout": 500.0,
Expand Down
7 changes: 3 additions & 4 deletions paperqa/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,11 @@
)
env_reset_prompt = (
"Use the tools to answer the question: {question}"
"\n\nThe {gen_answer_tool_name} tool output is visible to the user,"
" so you do not need to restate the answer"
" and can simply terminate if the answer looks sufficient."
"\n\nWhen the answer looks sufficient,"
" you can terminate by calling the {complete_tool_name} tool."
" If the answer does not look sufficient,"
" and you have already tried to answer several times,"
" you can terminate the question by specifying 0 tool calls."
" you can terminate by calling the {complete_tool_name} tool."
" The current status of evidence/papers/cost is {status}"
)

Expand Down
Loading