Skip to content

Commit

Permalink
Made key_filters part of Answer
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Jun 2, 2023
1 parent 2b97be3 commit d9c6472
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
19 changes: 15 additions & 4 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,12 @@ async def aget_evidence(
) -> Answer:
if len(self.docs) == 0:
return answer
if key_filter is not None:
answer.key_filter = key_filter
if self._faiss_index is None:
self._build_faiss_index()
_k = k
if key_filter is not None:
if answer.key_filter is not None:
_k = k * 10 # heuristic
# want to work through indices but less k
if marginal_relevance:
Expand All @@ -379,8 +381,17 @@ async def aget_evidence(
answer.question, k=_k, fetch_k=5 * _k
)
# ok now filter
if key_filter is not None:
docs = [doc for doc in docs if doc.metadata["dockey"] in key_filter][:k]
if answer.key_filter is not None:
# NOTE: this is a plain substring test against the key filter
# text, so odd matches can happen - e.g. the key FooBar also
# matches when the filter only mentions FooBar2023.
# That is acceptable because there are later checks.
# The risk of explicitly parsing the keys instead is that the
# language model may not give them back in a predictable
# format (e.g., "FooBar and BarSoo are good papers").
docs = [doc for doc in docs if doc.metadata["dockey"] in answer.key_filter][
:k
]

async def process(doc):
if doc.metadata["dockey"] in self._deleted_keys:
Expand Down Expand Up @@ -524,12 +535,12 @@ async def aquery(
answer.tokens += callbacks[0].total_tokens
answer.cost += callbacks[0].total_cost
key_filter = True if len(keys) > 0 else False
answer.key_filter = keys
answer = await self.aget_evidence(
answer,
k=k,
max_sources=max_sources,
marginal_relevance=marginal_relevance,
key_filter=keys if key_filter else None,
get_callbacks=get_callbacks,
)
context_str, contexts = answer.context, answer.contexts
Expand Down
3 changes: 2 additions & 1 deletion paperqa/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Optional

StrPath = Union[str, Path]

Expand Down Expand Up @@ -33,6 +33,7 @@ class Answer:
passages: Dict[str, str] = None
tokens: int = 0
cost: float = 0
key_filter: Optional[str] = None

def __post_init__(self):
"""Initialize the answer."""
Expand Down
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.0.2"
__version__ = "2.0.3"

0 comments on commit d9c6472

Please sign in to comment.