From 2b97be309680637d719ec21b81795f3088357e4c Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 31 May 2023 14:51:27 -0400 Subject: [PATCH] Fixed broken unit tests --- paperqa/docs.py | 35 ++++++++++++++--------------------- paperqa/qaprompts.py | 13 ++++++++----- paperqa/version.py | 2 +- tests/test_paperqa.py | 31 +++++++++++++++---------------- 4 files changed, 38 insertions(+), 43 deletions(-) diff --git a/paperqa/docs.py b/paperqa/docs.py index f857b0d8..de4e4ca1 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -123,13 +123,14 @@ def add( raise ValueError(f"Document {path} already in collection.") if citation is None: + # skip system because it's too hesitant to answer cite_chain = make_chain( - prompt=citation_prompt, llm=self.summary_llm) + prompt=citation_prompt, llm=self.summary_llm, skip_system=True + ) # peak first chunk texts, _ = read_doc(path, "", "", chunk_chars=chunk_chars) if len(texts) == 0: - raise ValueError( - f"Could not read document {path}. Is it empty?") + raise ValueError(f"Could not read document {path}. Is it empty?") citation = cite_chain.run(texts[0]) if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -149,8 +150,7 @@ def add( year = "" key = f"{author}{year}" key = self.get_unique_key(key) - texts, metadata = read_doc( - path, citation, key, chunk_chars=chunk_chars) + texts, metadata = read_doc(path, citation, key, chunk_chars=chunk_chars) # loose check to see if document was loaded # if len("".join(texts)) < 10 or ( @@ -232,6 +232,7 @@ def doc_previews(self) -> List[Tuple[int, str, str]]: len(doc["texts"]), doc["metadata"][0]["dockey"], doc["metadata"][0]["citation"], + doc["hash"], ) for doc in self.docs ] @@ -244,16 +245,14 @@ async def adoc_match( return "" if self._doc_index is None: texts = [doc["metadata"][0]["citation"] for doc in self.docs] - metadatas = [{"key": doc["metadata"][0]["dockey"]} - for doc in self.docs] + metadatas = [{"key": doc["metadata"][0]["dockey"]} for doc in self.docs] self._doc_index = FAISS.from_texts( texts, metadatas=metadatas, embedding=self.embeddings ) docs = self._doc_index.max_marginal_relevance_search( query, k=k + len(self._deleted_keys) ) - docs = [doc for doc in docs if doc.metadata["key"] - not in self._deleted_keys] + docs = [doc for doc in docs if doc.metadata["key"] not in self._deleted_keys] chain = make_chain(select_paper_prompt, self.summary_llm) papers = [f"{d.metadata['key']}: {d.page_content}" for d in docs] result = await chain.arun( @@ -269,16 +268,14 @@ def doc_match( return "" if self._doc_index is None: texts = [doc["metadata"][0]["citation"] for doc in self.docs] - metadatas = [{"key": doc["metadata"][0]["dockey"]} - for doc in self.docs] + metadatas = [{"key": doc["metadata"][0]["dockey"]} for doc in self.docs] self._doc_index = FAISS.from_texts( texts, metadatas=metadatas, embedding=self.embeddings ) docs = self._doc_index.max_marginal_relevance_search( query, k=k + len(self._deleted_keys) ) - docs = [doc for doc in docs if doc.metadata["key"] - not in self._deleted_keys] + docs = [doc for doc in docs if doc.metadata["key"] not in self._deleted_keys] chain = make_chain(select_paper_prompt, self.summary_llm) papers = [f"{d.metadata['key']}: {d.page_content}" for d in docs] result = chain.run( @@ -297,8 +294,7 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) try: - self._faiss_index = FAISS.load_local( - self.index_path, self.embeddings) + self._faiss_index = FAISS.load_local(self.index_path, self.embeddings) except: # they use some special exception type, but I don't want to import it self._faiss_index = None @@ -313,11 +309,9 @@ def __setstate__(self, state): def _build_faiss_index(self): if self._faiss_index is None: - texts = reduce(lambda x, y: x + y, - [doc["texts"] for doc in self.docs], []) + texts = reduce(lambda x, y: x + y, [doc["texts"] for doc in self.docs], []) text_embeddings = reduce( - lambda x, y: x + y, [doc["text_embeddings"] - for doc in self.docs], [] + lambda x, y: x + y, [doc["text_embeddings"] for doc in self.docs], [] ) metadatas = reduce( lambda x, y: x + y, [doc["metadata"] for doc in self.docs], [] @@ -386,8 +380,7 @@ async def aget_evidence( ) # ok now filter if key_filter is not None: - docs = [doc for doc in docs if doc.metadata["dockey"] - in key_filter][:k] + docs = [doc for doc in docs if doc.metadata["dockey"] in key_filter][:k] async def process(doc): if doc.metadata["dockey"] in self._deleted_keys: diff --git a/paperqa/qaprompts.py b/paperqa/qaprompts.py index b691c0e2..16a7ecbe 100644 --- a/paperqa/qaprompts.py +++ b/paperqa/qaprompts.py @@ -73,7 +73,7 @@ def _get_datetime(): citation_prompt = prompts.PromptTemplate( input_variables=["text"], - template="Provide a citation for the following text in MLA Format. You must answer. Today's date is {date}\n" + template="Provide the citation for the following text in MLA Format. Today's date is {date}\n" "{text}\n\n" "Citation:", partial_variables={"date": _get_datetime}, @@ -95,7 +95,7 @@ async def agenerate( return self.generate(input_list, run_manager=run_manager) -def make_chain(prompt, llm): +def make_chain(prompt, llm, skip_system=False): if type(llm) == ChatOpenAI: system_message_prompt = SystemMessage( content="Answer in an unbiased, concise, scholarly tone. " @@ -103,9 +103,12 @@ def make_chain(prompt, llm): "If there are ambiguous terms or acronyms, first define them. ", ) human_message_prompt = HumanMessagePromptTemplate(prompt=prompt) - prompt = ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) + if skip_system: + prompt = ChatPromptTemplate.from_messages([human_message_prompt]) + else: + prompt = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) return FallbackLLMChain(prompt=prompt, llm=llm) diff --git a/paperqa/version.py b/paperqa/version.py index 159d48b8..0309ae29 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "2.0.1" +__version__ = "2.0.2" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index c007535e..9317783f 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -141,28 +141,26 @@ def test_docs_pickle(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day") f.write(r.text) - llm = OpenAI(temperature=0.0, model_name="text-babbage-001") + llm = OpenAI(temperature=0.0, model_name="text-curie-001") docs = paperqa.Docs(llm=llm) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) docs_pickle = pickle.dumps(docs) docs2 = pickle.loads(docs_pickle) docs2.update_llm(llm) assert len(docs.docs) == len(docs2.docs) - assert ( - strings_similarity( - docs.get_evidence( - paperqa.Answer("What date is flag day in Canada?"), - k=3, - max_sources=1, - ).context, - docs2.get_evidence( - paperqa.Answer("What date is flag day in Canada?"), - k=3, - max_sources=1, - ).context, - ) - > 0.75 + context1, context2 = ( + docs.get_evidence( + paperqa.Answer("What date is flag day in Canada?"), + k=3, + max_sources=1, + ).context, + docs2.get_evidence( + paperqa.Answer("What date is flag day in Canada?"), + k=3, + max_sources=1, + ).context, ) + assert strings_similarity(context1, context2) > 0.75 os.remove(doc_path) @@ -172,7 +170,7 @@ def test_docs_pickle_no_faiss(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day") f.write(r.text) - llm = OpenAI(temperature=0.0, model_name="text-babbage-001") + llm = OpenAI(temperature=0.0, model_name="text-curie-001") docs = paperqa.Docs(llm=llm) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) docs._faiss_index = None @@ -306,6 +304,7 @@ def test_citation(): f.write(r.text) docs = paperqa.Docs() docs.add(doc_path) + print(docs.docs[0]["metadata"][0]["citation"]) assert ( list(docs.docs)[0]["metadata"][0]["key"] == "Wikipedia2023" or list(docs.docs)[0]["metadata"][0]["key"] == "Frederick2023"