From 88d50e82cb8aca959d12eb66c0a6132b7f514394 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Tue, 30 May 2023 10:26:46 -0400
Subject: [PATCH] Allow failure on evidence gathering without crashing (#129)

* Allow failure on evidence gathering without crashing

* Fixed errors on key filters
---
 paperqa/docs.py      | 41 +++++++++++++++++++++++++++++------------
 paperqa/qaprompts.py | 31 ++++++++++++++++++++-----------
 paperqa/types.py     | 31 ++++++++++++++++---------------
 paperqa/utils.py     |  7 +++++++
 paperqa/version.py   |  2 +-
 5 files changed, 73 insertions(+), 39 deletions(-)

diff --git a/paperqa/docs.py b/paperqa/docs.py
index def9206b..5f3a030d 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -27,10 +27,11 @@
     search_prompt,
     select_paper_prompt,
     summary_prompt,
+    get_score,
 )
 from .readers import read_doc
 from .types import Answer, Context
-from .utils import maybe_is_text, md5sum, gather_with_concurrency
+from .utils import maybe_is_text, md5sum, gather_with_concurrency, guess_is_4xx
 
 os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
 langchain.llm_cache = SQLiteCache(CACHE_PATH)
@@ -373,12 +374,13 @@ async def aget_evidence(
         docs = self._faiss_index.similarity_search(
             answer.question, k=_k, fetch_k=5 * _k
         )
+        # now filter the retrieved docs down to the allowed keys
+        if key_filter is not None:
+            docs = [doc for doc in docs if doc.metadata["dockey"] in key_filter][:k]
 
         async def process(doc):
             if doc.metadata["dockey"] in self._deleted_keys:
                 return None, None
-            if key_filter is not None and doc.metadata["dockey"] not in key_filter:
-                return None, None
             # check if it is already in answer (possible in agent setting)
             if doc.metadata["key"] in [c.key for c in answer.contexts]:
                 return None, None
@@ -386,18 +388,32 @@ async def process(doc):
                 "evidence:" + doc.metadata["key"]
             )
             summary_chain = make_chain(summary_prompt, self.summary_llm)
-            c = Context(
-                key=doc.metadata["key"],
-                citation=doc.metadata["citation"],
-                context=await summary_chain.arun(
+            # This is dangerous because it could mask
+            # errors that are important. I also cannot
+            # know what the exception type is, because
+            # any model could be used; my best idea is
+            # to check for a 4XX HTTP code in the
+            # exception message.
+            try:
+                context = await summary_chain.arun(
                     question=answer.question,
                     context_str=doc.page_content,
                     citation=doc.metadata["citation"],
                     callbacks=callbacks,
-                ),
+                )
+            except Exception as e:
+                if guess_is_4xx(str(e)):
+                    return None, None
+                raise
+            c = Context(
+                key=doc.metadata["key"],
+                citation=doc.metadata["citation"],
+                context=context,
                 text=doc.page_content,
+                score=get_score(context),
             )
             if "not applicable" not in c.context.casefold():
+                print(c.score)
                 return c, callbacks[0]
             return None, None
 
@@ -411,7 +427,7 @@
         contexts = [c for c, _ in results if c is not None]
         if len(contexts) == 0:
             return answer
-        contexts = sorted(contexts, key=lambda x: len(x.context), reverse=True)
+        contexts = sorted(contexts, key=lambda x: x.score, reverse=True)
         contexts = contexts[:max_sources]
         # add to answer (if not already there)
         keys = [c.key for c in answer.contexts]
@@ -499,11 +515,12 @@ async def aquery(
         if answer is None:
             answer = Answer(query)
         if len(answer.contexts) == 0:
-            if key_filter or (key_filter is None and len(self.docs) > 5):
+            if key_filter or (key_filter is None and len(self.docs) > max_sources):
                 callbacks = [OpenAICallbackHandler()] + get_callbacks("filter")
                 keys = await self.adoc_match(answer.question, callbacks=callbacks)
                 answer.tokens += callbacks[0].total_tokens
                 answer.cost += callbacks[0].total_cost
+                key_filter = len(keys) > 0
             answer = await self.aget_evidence(
                 answer,
                 k=k,
@@ -532,8 +549,8 @@ async def aquery(
             answer.tokens += cb.total_tokens
             answer.cost += cb.total_cost
         # it still happens lol
-        if "(Foo2012)" in answer_text:
-            answer_text = answer_text.replace("(Foo2012)", "")
+        if "(Example2012)" in answer_text:
+            answer_text = answer_text.replace("(Example2012)", "")
         for c in contexts:
             key = c.key
             text = c.context
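Note on the docs.py change: the try/except above is the core of the fix. A 4xx-looking provider error is treated as "skip this document" while everything else is re-raised, so real failures still surface. A minimal sketch of the pattern in isolation (the summarize_one wrapper and its chain argument are illustrative, not part of the patch):

    import re

    def guess_is_4xx(msg: str) -> bool:
        # any 4xx-looking status code in the error text counts as a client error
        return re.search(r"4\d\d", msg) is not None

    async def summarize_one(chain, **kwargs):
        # hypothetical wrapper: skip (return None) on 4xx-style errors,
        # re-raise anything else so genuine bugs are not masked
        try:
            return await chain.arun(**kwargs)
        except Exception as e:
            if guess_is_4xx(str(e)):
                return None
            raise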
diff --git a/paperqa/qaprompts.py b/paperqa/qaprompts.py
index aed25240..b691c0e2 100644
--- a/paperqa/qaprompts.py
+++ b/paperqa/qaprompts.py
@@ -1,6 +1,7 @@
 import copy
 from datetime import datetime
 from typing import Any, Dict, List, Optional
+import re
 
 import langchain.prompts as prompts
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -11,12 +12,12 @@
 summary_prompt = prompts.PromptTemplate(
     input_variables=["question", "context_str", "citation"],
-    template="Summarize and provide direct quotes from the text below to help answer a question. "
-    "Do not directly answer the question, instead summarize and "
-    "quote to give evidence to help answer the question. "
-    "Do not use outside sources. "
-    'Reply with only "Not applicable" if the text is unrelated to the question. '
-    "Use 100 or less words."
+    template="Summarize the text below to help answer a question. "
+    "Do not directly answer the question, instead summarize "
+    "to give evidence to help answer the question. Include direct quotes. "
+    'Reply "Not applicable" if the text is irrelevant. '
+    "Use around 100 words. At the end of your response, provide a score from 1-10 on a newline "
+    "indicating relevance to the question. Do not explain your score. "
     "\n\n"
     "{context_str}\n"
     "Extracted from {citation}\n"
@@ -24,14 +25,13 @@
     "Relevant Information Summary:",
 )
 
-
 qa_prompt = prompts.PromptTemplate(
     input_variables=["question", "context_str", "length"],
     template="Write an answer ({length}) "
     "for the question below based on the provided context. "
     "If the context provides insufficient information, "
     'reply "I cannot answer". '
-    "For each sentence in your answer, indicate which sources most support it "
+    "For each part of your answer, indicate which sources most support it "
     "via valid citation markers at the end of sentences, like (Example2012). "
     "Answer in an unbiased, comprehensive, and scholarly tone. "
     "If the question is subjective, provide an opinionated answer in the concluding 1-2 sentences. "
@@ -98,12 +98,21 @@ async def agenerate(
 def make_chain(prompt, llm):
     if type(llm) == ChatOpenAI:
         system_message_prompt = SystemMessage(
-            content="You are a scholarly researcher that answers in an unbiased, concise, scholarly tone. "
-            "You sometimes refuse to answer if there is insufficient information. "
-            "If there are potentially ambiguous terms or acronyms, first define them. ",
+            content="Answer in an unbiased, concise, scholarly tone. "
+            "You may refuse to answer if there is insufficient information. "
+            "If there are ambiguous terms or acronyms, first define them. ",
         )
         human_message_prompt = HumanMessagePromptTemplate(prompt=prompt)
         prompt = ChatPromptTemplate.from_messages(
             [system_message_prompt, human_message_prompt]
         )
     return FallbackLLMChain(prompt=prompt, llm=llm)
+
+
+def get_score(text):
+    score = re.search(r"[sS]core[:is\s]+([0-9]+)", text)
+    if score:
+        return int(score.group(1))
+    if len(text) < 100:
+        return 1
+    return 5
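The new get_score helper pairs with the reworded summary prompt: the model is asked for a 1-10 relevance score on a trailing newline, and the regex tolerates variants like "Score: 8" or "the score is 8". A few invented completions show the fallback behavior (assuming the function as added above):

    from paperqa.qaprompts import get_score

    assert get_score("Direct quote...\nScore: 8") == 8
    assert get_score("The text discusses X.\nThe score is 3") == 3
    # no score found: short replies (e.g. "Not applicable") get a 1,
    # longer ones a neutral 5
    assert get_score("Not applicable") == 1
    assert get_score("A long summary with no rating line at all. " * 4) == 5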
", ) human_message_prompt = HumanMessagePromptTemplate(prompt=prompt) prompt = ChatPromptTemplate.from_messages( [system_message_prompt, human_message_prompt] ) return FallbackLLMChain(prompt=prompt, llm=llm) + + +def get_score(text): + score = re.search(r"[sS]core[:is\s]+([0-9]+)", text) + if score: + return int(score.group(1)) + if len(text) < 100: + return 1 + return 5 diff --git a/paperqa/types.py b/paperqa/types.py index ac66b7c8..44c6556e 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -5,6 +5,21 @@ StrPath = Union[str, Path] +@dataclass +class Context: + """A class to hold the context of a question.""" + + key: str + citation: str + context: str + text: str + score: int = 5 + + def __str__(self) -> str: + """Return the context as a string.""" + return self.context + + @dataclass class Answer: """A class to hold the answer to a question.""" @@ -12,7 +27,7 @@ class Answer: question: str answer: str = "" context: str = "" - contexts: List[Any] = None + contexts: List[Context] = None references: str = "" formatted_answer: str = "" passages: Dict[str, str] = None @@ -29,17 +44,3 @@ def __post_init__(self): def __str__(self) -> str: """Return the answer as a string.""" return self.formatted_answer - - -@dataclass -class Context: - """A class to hold the context of a question.""" - - key: str - citation: str - context: str - text: str - - def __str__(self) -> str: - """Return the context as a string.""" - return self.context diff --git a/paperqa/utils.py b/paperqa/utils.py index 7a7ef421..9e741612 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -1,5 +1,6 @@ import math import string +import re import asyncio import pypdf @@ -80,3 +81,9 @@ async def sem_coro(coro): return await coro return await asyncio.gather(*(sem_coro(c) for c in coros)) + + +def guess_is_4xx(msg: str) -> bool: + if re.search(r"4\d\d", msg): + return True + return False diff --git a/paperqa/version.py b/paperqa/version.py index b518f6ee..9a34ccc9 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "1.12.0" +__version__ = "1.13.0"