From 22bd7ccf8df05589c5fdebec5ec84ea1cbca4a47 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 27 Sep 2024 00:51:59 -0700 Subject: [PATCH] Fixed task deserialization from config, with a test --- paperqa/agents/task.py | 16 ++++++++++++---- tests/test_task.py | 23 ++++++++++++++++++++--- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index 964e2084..59731d67 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -123,14 +123,22 @@ class LitQATaskDataset( def __init__( self, - base_query: QueryRequest | None = None, - base_docs: Docs | None = None, + base_query: QueryRequest | dict | None = None, + base_docs: Docs | dict | None = None, rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION, eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME, **env_kwargs, ): - self._base_query = base_query or QueryRequest() - self._base_docs = base_docs or Docs() + if base_query is None: + base_query = QueryRequest() + if isinstance(base_query, dict): + base_query = QueryRequest(**base_query) + self._base_query = base_query + if base_docs is None: + base_docs = Docs() + if isinstance(base_docs, dict): + base_docs = Docs(**base_docs) + self._base_docs = base_docs self._rewards = rewards self._env_kwargs = env_kwargs self._eval_model = eval_model diff --git a/tests/test_task.py b/tests/test_task.py index 24ab5054..d9aec241 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from aviary.env import TASK_DATASET_REGISTRY, TaskDataset +from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset from ldp.agent import SimpleAgent from ldp.alg.callbacks import MeanMetricsCallback from ldp.alg.runners import Evaluator, EvaluatorConfig @@ -78,9 +78,26 @@ def test___len__( @pytest.mark.asyncio async def test_evaluation(self, base_query_request: QueryRequest) -> None: docs = Docs() - dataset = TaskDataset.from_name( - STUB_TASK_DATASET_NAME, base_query=base_query_request, base_docs=docs + # Why are we constructing a TaskConfig here using a serialized QueryRequest and + # Docs? It's to confirm everything works as if hydrating from a YAML config file + task_config = TaskConfig( + name=STUB_TASK_DATASET_NAME, + eval_kwargs={ + "base_query": base_query_request.model_dump( + exclude={"id", "settings", "docs_name"} + ), + "base_docs": docs.model_dump( + exclude={ + "id", + "docnames", + "texts_index", + "index_path", + "deleted_dockeys", + } + ), + }, ) + dataset = task_config.make_dataset(split="eval") # noqa: FURB184 metrics_callback = MeanMetricsCallback(eval_dataset=dataset) evaluator = Evaluator(