
Commit

Moved GradablePaperQAEnvironment.step to use the answer, removing GenerateAnswer parsing complexity
jamesbraza committed Nov 15, 2024
1 parent e71fc64 commit cccdaff
Showing 1 changed file with 1 addition and 25 deletions.
26 changes: 1 addition & 25 deletions paperqa/agents/task.py
@@ -130,31 +130,7 @@ async def step(
         messages, reward, done, truncated = await super().step(action)
         if not done or not self._evaluation_from_answer:
             return messages, reward, done, truncated
-        valid_answers, failed_answer_messages = [], []
-        for m in messages:
-            if (
-                not isinstance(m, ToolResponseMessage)
-                or m.name != GenerateAnswer.gen_answer.__name__
-            ):
-                continue  # Filter out non-answer messages (in case parallel tool calls)
-            if answer := GenerateAnswer.extract_answer_from_message(content=m.content):
-                valid_answers.append(answer)
-            else:
-                failed_answer_messages.append(m)
-        if not valid_answers:  # No answer, so no positive reward
-            return messages, reward, done, truncated
-        if len(valid_answers) != 1:
-            raise NotImplementedError(
-                f"Expected just one answer message, got more than one in {messages}."
-            )
-        answer = valid_answers[0]
-        if failed_answer_messages:
-            logger.warning(
-                "More than one answer detected, discarding failed answer messages"
-                f" {failed_answer_messages}, continuing with answer {answer}."
-            )
-        # Okay, so we have one answer that was not a failed answer. Let's evaluate it
-        evaluation = await self._evaluation_from_answer(answer)
+        evaluation = await self._evaluation_from_answer(self.state.session.answer)
         if evaluation_callback := self._evaluation_callback:
             await evaluation_callback(evaluation)
         return messages, reward + self._rewards[evaluation.value], done, truncated
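
The net effect of the change: step now grades self.state.session.answer directly instead of locating and parsing a GenerateAnswer tool response among the returned messages. Below is a minimal, self-contained sketch of that evaluate-then-reward pattern; the Evaluation enum, the evaluation_from_answer grader, and the reward values are hypothetical illustrations for this sketch, not paperqa's actual definitions.

# Hypothetical sketch of the evaluate-then-reward pattern used in step above.
import asyncio
from enum import IntEnum


class Evaluation(IntEnum):
    # Stand-in for the evaluation enum whose .value indexes self._rewards.
    INCORRECT = 0
    CORRECT = 1


async def evaluation_from_answer(answer: str) -> Evaluation:
    # Toy grader; the real callable might compare against a gold answer or use an LLM.
    return Evaluation.CORRECT if "42" in answer else Evaluation.INCORRECT


async def main() -> None:
    rewards = {Evaluation.INCORRECT.value: -1.0, Evaluation.CORRECT.value: 1.0}
    session_answer = "The answer is 42."  # stands in for self.state.session.answer
    evaluation = await evaluation_from_answer(session_answer)
    reward = 0.0 + rewards[evaluation.value]  # mirrors reward + self._rewards[evaluation.value]
    print(f"{evaluation.name=} {reward=}")


if __name__ == "__main__":
    asyncio.run(main())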
