Allow failure on evidence gathering without crashing (#129)
* Allow failure on evidence gathering without crashing

* Fixed errors on key filters
whitead committed May 30, 2023
1 parent bda7bef commit 88d50e8
Showing 5 changed files with 73 additions and 39 deletions.
41 changes: 29 additions & 12 deletions paperqa/docs.py
@@ -27,10 +27,11 @@
     search_prompt,
     select_paper_prompt,
     summary_prompt,
+    get_score,
 )
 from .readers import read_doc
 from .types import Answer, Context
-from .utils import maybe_is_text, md5sum, gather_with_concurrency
+from .utils import maybe_is_text, md5sum, gather_with_concurrency, guess_is_4xx
 
 os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
 langchain.llm_cache = SQLiteCache(CACHE_PATH)
@@ -373,31 +374,46 @@ async def aget_evidence(
         docs = self._faiss_index.similarity_search(
             answer.question, k=_k, fetch_k=5 * _k
         )
-        # ok now filter
-        if key_filter is not None:
-            docs = [doc for doc in docs if doc.metadata["dockey"] in key_filter][:k]
 
         async def process(doc):
             if doc.metadata["dockey"] in self._deleted_keys:
                 return None, None
+            if key_filter is not None and doc.metadata["dockey"] not in key_filter:
+                return None, None
             # check if it is already in answer (possible in agent setting)
             if doc.metadata["key"] in [c.key for c in answer.contexts]:
                 return None, None
             callbacks = [OpenAICallbackHandler()] + get_callbacks(
                 "evidence:" + doc.metadata["key"]
             )
             summary_chain = make_chain(summary_prompt, self.summary_llm)
-            c = Context(
-                key=doc.metadata["key"],
-                citation=doc.metadata["citation"],
-                context=await summary_chain.arun(
+            # This is dangerous because it
+            # could mask errors that are important.
+            # I also cannot know what the exception
+            # type is because any model could be used;
+            # my best idea is to see if there is a 4XX
+            # HTTP code in the exception.
+            try:
+                context = await summary_chain.arun(
                     question=answer.question,
                     context_str=doc.page_content,
                     citation=doc.metadata["citation"],
                     callbacks=callbacks,
-                ),
-            )
+                )
+            except Exception as e:
+                if guess_is_4xx(str(e)):
+                    return None, None
+                raise e
+            c = Context(
+                key=doc.metadata["key"],
+                citation=doc.metadata["citation"],
+                context=context,
                 text=doc.page_content,
+                score=get_score(context),
             )
             if "not applicable" not in c.context.casefold():
+                print(c.score)
                 return c, callbacks[0]
             return None, None
 
@@ -411,7 +427,7 @@ async def process(doc):
         contexts = [c for c, _ in results if c is not None]
         if len(contexts) == 0:
             return answer
-        contexts = sorted(contexts, key=lambda x: len(x.context), reverse=True)
+        contexts = sorted(contexts, key=lambda x: x.score, reverse=True)
         contexts = contexts[:max_sources]
         # add to answer (if not already there)
         keys = [c.key for c in answer.contexts]
@@ -499,11 +515,12 @@ async def aquery(
         if answer is None:
             answer = Answer(query)
         if len(answer.contexts) == 0:
-            if key_filter or (key_filter is None and len(self.docs) > 5):
+            if key_filter or (key_filter is None and len(self.docs) > max_sources):
                 callbacks = [OpenAICallbackHandler()] + get_callbacks("filter")
                 keys = await self.adoc_match(answer.question, callbacks=callbacks)
                 answer.tokens += callbacks[0].total_tokens
                 answer.cost += callbacks[0].total_cost
+                key_filter = True if len(keys) > 0 else False
             answer = await self.aget_evidence(
                 answer,
                 k=k,
@@ -532,8 +549,8 @@ async def aquery(
         answer.tokens += cb.total_tokens
         answer.cost += cb.total_cost
         # it still happens lol
-        if "(Foo2012)" in answer_text:
-            answer_text = answer_text.replace("(Foo2012)", "")
+        if "(Example2012)" in answer_text:
+            answer_text = answer_text.replace("(Example2012)", "")
         for c in contexts:
             key = c.key
             text = c.context
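Taken together, the docs.py changes wrap each per-document summary call so that an apparent client-side (4xx) error skips that one document instead of crashing the surrounding asyncio.gather. A minimal standalone sketch of the pattern, with a hypothetical summarize callable and helper names (only the guess_is_4xx heuristic comes from this commit):

    import asyncio
    import re

    def guess_is_4xx(msg: str) -> bool:
        # Heuristic from this commit: any "4xx"-looking number in the
        # stringified error is treated as a client-side HTTP failure.
        return bool(re.search(r"4\d\d", msg))

    async def summarize_or_skip(summarize, doc):
        # Hypothetical wrapper: run one summary task, drop the document
        # on an apparent 4xx error, and re-raise anything else.
        try:
            return await summarize(doc)
        except Exception as e:
            if guess_is_4xx(str(e)):
                return None
            raise

    async def gather_evidence(summarize, docs):
        results = await asyncio.gather(
            *(summarize_or_skip(summarize, d) for d in docs)
        )
        return [r for r in results if r is not None]

Anything that does not look like a 4xx is re-raised, so genuine bugs still surface.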
31 changes: 20 additions & 11 deletions paperqa/qaprompts.py
@@ -1,6 +1,7 @@
 import copy
 from datetime import datetime
 from typing import Any, Dict, List, Optional
+import re
 
 import langchain.prompts as prompts
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -11,27 +12,26 @@
 
 summary_prompt = prompts.PromptTemplate(
     input_variables=["question", "context_str", "citation"],
-    template="Summarize and provide direct quotes from the text below to help answer a question. "
-    "Do not directly answer the question, instead summarize and "
-    "quote to give evidence to help answer the question. "
-    "Do not use outside sources. "
-    'Reply with only "Not applicable" if the text is unrelated to the question. '
-    "Use 100 or less words."
+    template="Summarize the text below to help answer a question. "
+    "Do not directly answer the question, instead summarize "
+    "to give evidence to help answer the question. Include direct quotes. "
+    'Reply "Not applicable" if text is irrelevant. '
+    "Use around 100 words. At the end of your response, provide a score from 1-10 on a newline "
+    "indicating relevance to question. Do not explain your score. "
     "\n\n"
     "{context_str}\n"
     "Extracted from {citation}\n"
     "Question: {question}\n"
     "Relevant Information Summary:",
 )
 
 
 qa_prompt = prompts.PromptTemplate(
     input_variables=["question", "context_str", "length"],
     template="Write an answer ({length}) "
     "for the question below based on the provided context. "
     "If the context provides insufficient information, "
     'reply "I cannot answer". '
-    "For each sentence in your answer, indicate which sources most support it "
+    "For each part of your answer, indicate which sources most support it "
     "via valid citation markers at the end of sentences, like (Example2012). "
     "Answer in an unbiased, comprehensive, and scholarly tone. "
     "If the question is subjective, provide an opinionated answer in the concluding 1-2 sentences. "
@@ -98,12 +98,21 @@ async def agenerate(
 def make_chain(prompt, llm):
     if type(llm) == ChatOpenAI:
         system_message_prompt = SystemMessage(
-            content="You are a scholarly researcher that answers in an unbiased, concise, scholarly tone. "
-            "You sometimes refuse to answer if there is insufficient information. "
-            "If there are potentially ambiguous terms or acronyms, first define them. ",
+            content="Answer in an unbiased, concise, scholarly tone. "
+            "You may refuse to answer if there is insufficient information. "
+            "If there are ambiguous terms or acronyms, first define them. ",
         )
         human_message_prompt = HumanMessagePromptTemplate(prompt=prompt)
         prompt = ChatPromptTemplate.from_messages(
             [system_message_prompt, human_message_prompt]
         )
     return FallbackLLMChain(prompt=prompt, llm=llm)
+
+
+def get_score(text):
+    score = re.search(r"[sS]core[:is\s]+([0-9]+)", text)
+    if score:
+        return int(score.group(1))
+    if len(text) < 100:
+        return 1
+    return 5
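The revised summary_prompt asks the model to end its reply with a relevance score from 1-10 on its own line, and get_score extracts that integer, falling back to 1 for very short replies such as "Not applicable" and to a neutral 5 when no score line is found. A quick check with invented model outputs:

    from paperqa.qaprompts import get_score

    print(get_score("The study reports a 12% yield increase.\nScore: 8"))  # 8
    print(get_score("Not applicable"))                          # 1 (short-text fallback)
    print(get_score("A long summary with no score line. " * 10))  # 5 (neutral default)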
31 changes: 16 additions & 15 deletions paperqa/types.py
@@ -5,14 +5,29 @@
 StrPath = Union[str, Path]
 
 
+@dataclass
+class Context:
+    """A class to hold the context of a question."""
+
+    key: str
+    citation: str
+    context: str
+    text: str
+    score: int = 5
+
+    def __str__(self) -> str:
+        """Return the context as a string."""
+        return self.context
+
+
 @dataclass
 class Answer:
     """A class to hold the answer to a question."""
 
     question: str
     answer: str = ""
     context: str = ""
-    contexts: List[Any] = None
+    contexts: List[Context] = None
     references: str = ""
     formatted_answer: str = ""
     passages: Dict[str, str] = None
@@ -29,17 +44,3 @@ def __post_init__(self):
     def __str__(self) -> str:
         """Return the answer as a string."""
         return self.formatted_answer
-
-
-@dataclass
-class Context:
-    """A class to hold the context of a question."""
-
-    key: str
-    citation: str
-    context: str
-    text: str
-
-    def __str__(self) -> str:
-        """Return the context as a string."""
-        return self.context
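Moving Context above Answer lets the contexts: List[Context] annotation resolve at class-creation time, and the new score field (default 5) is what aget_evidence now sorts on instead of summary length. A small illustration with invented values:

    from paperqa.types import Context

    contexts = [
        Context(key="Smith2020", citation="Smith, 2020.", context="...", text="...", score=8),
        Context(key="Lee2021", citation="Lee, 2021.", context="...", text="..."),  # score defaults to 5
    ]
    # Mirrors the docs.py change: rank by model-assigned relevance.
    ranked = sorted(contexts, key=lambda x: x.score, reverse=True)  # Smith2020 first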
7 changes: 7 additions & 0 deletions paperqa/utils.py
@@ -1,5 +1,6 @@
 import math
 import string
+import re
 import asyncio
 
 import pypdf
@@ -80,3 +81,9 @@ async def sem_coro(coro):
         return await coro
 
     return await asyncio.gather(*(sem_coro(c) for c in coros))
+
+
+def guess_is_4xx(msg: str) -> bool:
+    if re.search(r"4\d\d", msg):
+        return True
+    return False
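guess_is_4xx is deliberately loose: since the exception type cannot be known in advance, it just scans the stringified error for any three-digit number beginning with 4. A few illustrative calls (error strings invented):

    print(guess_is_4xx("InvalidRequestError: server returned status 400"))  # True
    print(guess_is_4xx("Rate limit exceeded (HTTP 429)"))                   # True
    print(guess_is_4xx("Connection reset by peer"))                         # False
    # The heuristic can misfire on unrelated numbers:
    print(guess_is_4xx("maximum context length is 4097 tokens"))            # True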
2 changes: 1 addition & 1 deletion paperqa/version.py
@@ -1 +1 @@
__version__ = "1.12.0"
__version__ = "1.13.0"
