diff --git a/tests/test_task.py b/tests/test_task.py
index cddecc12..a64c253f 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -5,8 +5,9 @@
 import pytest
 from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset
 from ldp.agent import SimpleAgent
-from ldp.alg.callbacks import MeanMetricsCallback
+from ldp.alg.callbacks import MeanMetricsCallback, StoreTrajectoriesCallback
 from ldp.alg.runners import Evaluator, EvaluatorConfig
+from pytest_subtests import SubTests
 
 from paperqa import Docs, QueryRequest, Settings
 from paperqa.agents import get_directory_index
@@ -16,6 +17,7 @@
     LitQAv2TaskDataset,
     LitQAv2TaskSplit,
 )
+from paperqa.agents.tools import GenerateAnswer
 
 
 @pytest.fixture(name="base_query_request")
@@ -106,7 +108,9 @@ async def test_can_validate_stub_dataset_sources(
         )
 
     @pytest.mark.asyncio
-    async def test_evaluation(self, base_query_request: QueryRequest) -> None:
+    async def test_evaluation(
+        self, subtests: SubTests, base_query_request: QueryRequest
+    ) -> None:
         await get_directory_index(settings=base_query_request.settings)  # Build
         docs = Docs()
         # Why are we constructing a TaskConfig here using a serialized QueryRequest and
@@ -150,6 +154,31 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
             isinstance(metrics_callback.eval_means["reward"], float) > 0
         ), "Expected some wins"
 
+        with subtests.test(msg="zero-shot"):
+            # Confirm we can just directly call gen_answer
+            base_query_request.settings.agent.tool_names = {
+                GenerateAnswer.gen_answer.__name__
+            }
+            base_query_request.settings.answer.get_evidence_if_no_contexts = False
+            dataset = LitQAv2TaskDataset(
+                base_query=base_query_request, split=LitQAv2TaskSplit.EVAL
+            )
+            dataset.data = dataset.data[:2]  # Save the world: just use two questions
+            storage_callback = StoreTrajectoriesCallback()
+            evaluator = Evaluator(
+                config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
+                agent=SimpleAgent(),
+                dataset=dataset,
+                callbacks=[storage_callback],
+            )
+            await evaluator.evaluate()
+            for traj in storage_callback.eval_trajectories:
+                for step in traj.steps:
+                    assert all(
+                        tc.function.name == GenerateAnswer.gen_answer.__name__
+                        for tc in step.action.value.tool_calls
+                    )
+
     @pytest.mark.vcr
     @pytest.mark.asyncio
     async def test_tool_failure(self, base_query_request: QueryRequest) -> None: