From bcaf43d91b43735a6c5ac45eb29dba3a15e8e291 Mon Sep 17 00:00:00 2001 From: James Braza Date: Sat, 2 Nov 2024 15:29:36 -0700 Subject: [PATCH] Ability to zero-shot `gen_answer` (#658) --- .pre-commit-config.yaml | 2 +- paperqa/docs.py | 3 +-- paperqa/settings.py | 7 +++++++ pyproject.toml | 2 +- tests/test_task.py | 31 +++++++++++++++++++++++++++++-- uv.lock | 4 ++-- 6 files changed, 41 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31f37c1b5..502a5fbb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,7 +86,7 @@ repos: - aiohttp - coredis - fhaviary[llm]>=0.8.2 # Match pyproject.toml - - ldp>=0.9 # Match pyproject.toml + - ldp>=0.12 # Match pyproject.toml - html2text - httpx - limits diff --git a/paperqa/docs.py b/paperqa/docs.py index ce95ef1bb..f10f5c15f 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -672,8 +672,7 @@ async def aquery( # noqa: PLR0912 ) contexts = session.contexts - - if not contexts: + if answer_config.get_evidence_if_no_contexts and not contexts: session = await self.aget_evidence( session, callbacks=callbacks, diff --git a/paperqa/settings.py b/paperqa/settings.py index f921b6efc..a1d80013c 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -93,6 +93,13 @@ class AnswerSettings(BaseModel): default=False, description="Whether to cite background information provided by model.", ) + get_evidence_if_no_contexts: bool = Field( + default=True, + description=( + "Opt-out flag for allowing answer generation to lazily gather evidence if" + " called before evidence was gathered." + ), + ) @model_validator(mode="after") def _deprecated_field(self) -> Self: diff --git a/pyproject.toml b/pyproject.toml index 4508a8be9..dd4d70858 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ datasets = [ "datasets", ] ldp = [ - "ldp>=0.9", # For alg namespace grouping + "ldp>=0.12", # For StoreTrajectoriesCallback ] local = [ "sentence-transformers", diff --git a/tests/test_task.py b/tests/test_task.py index cddecc12b..d628bc33a 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -5,8 +5,9 @@ import pytest from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset from ldp.agent import SimpleAgent -from ldp.alg.callbacks import MeanMetricsCallback +from ldp.alg.callbacks import MeanMetricsCallback, StoreTrajectoriesCallback from ldp.alg.runners import Evaluator, EvaluatorConfig +from pytest_subtests import SubTests from paperqa import Docs, QueryRequest, Settings from paperqa.agents import get_directory_index @@ -16,6 +17,7 @@ LitQAv2TaskDataset, LitQAv2TaskSplit, ) +from paperqa.agents.tools import GenerateAnswer @pytest.fixture(name="base_query_request") @@ -106,7 +108,9 @@ async def test_can_validate_stub_dataset_sources( ) @pytest.mark.asyncio - async def test_evaluation(self, base_query_request: QueryRequest) -> None: + async def test_evaluation( + self, subtests: SubTests, base_query_request: QueryRequest + ) -> None: await get_directory_index(settings=base_query_request.settings) # Build docs = Docs() # Why are we constructing a TaskConfig here using a serialized QueryRequest and @@ -150,6 +154,29 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None: isinstance(metrics_callback.eval_means["reward"], float) > 0 ), "Expected some wins" + with subtests.test(msg="zero-shot"): + # Confirm we can just directly call gen_answer + base_query_request.settings.agent.tool_names = { + GenerateAnswer.gen_answer.__name__ + } + base_query_request.settings.answer.get_evidence_if_no_contexts = False + dataset = LitQAv2TaskDataset(base_query=base_query_request) + dataset.data = dataset.data[:2] # Save the world: just use two questions + storage_callback = StoreTrajectoriesCallback() + evaluator = Evaluator( + config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2), + agent=SimpleAgent(), + dataset=dataset, + callbacks=[storage_callback], + ) + await evaluator.evaluate() + for traj in storage_callback.eval_trajectories: + for step in traj.steps: + assert all( + tc.function.name == GenerateAnswer.gen_answer.__name__ + for tc in step.action.value.tool_calls + ) + @pytest.mark.vcr @pytest.mark.asyncio async def test_tool_failure(self, base_query_request: QueryRequest) -> None: diff --git a/uv.lock b/uv.lock index a38e1193d..9f2cbb626 100644 --- a/uv.lock +++ b/uv.lock @@ -1505,7 +1505,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.3.3.dev1+gdcb4fd4" +version = "5.3.3.dev5+gc849dd1" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1584,7 +1584,7 @@ requires-dist = [ { name = "fhaviary", extras = ["llm"], specifier = ">=0.8.2" }, { name = "html2text" }, { name = "httpx" }, - { name = "ldp", marker = "extra == 'ldp'", specifier = ">=0.9" }, + { name = "ldp", marker = "extra == 'ldp'", specifier = ">=0.12" }, { name = "limits" }, { name = "litellm", specifier = ">=1.44" }, { name = "numpy" },