diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index 59731d67..f2ccb5fe 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -175,15 +175,20 @@ def compute_trajectory_metrics( evidence_count: list[float] = [] for t in trajectories: split_answers = [ - re.split( - pattern=GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN, - string=obs.content, - ) - for obs in t.steps[-1].next_observation - if ( - isinstance(obs, ToolResponseMessage) - and obs.name == GenerateAnswer.TOOL_FN_NAME + split_answers + for split_answers in ( + re.split( + pattern=GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN, + string=obs.content, + ) + for obs in t.steps[-1].next_observation + if ( + isinstance(obs, ToolResponseMessage) + and obs.name == GenerateAnswer.TOOL_FN_NAME + ) ) + # Filter for places where the regex split succeeded + if len(split_answers) >= 4 # noqa: PLR2004 ] for i, metric_list in enumerate( (total_paper_count, relevant_paper_count, evidence_count),