Commit

Fixed broken unit tests

whitead committed May 31, 2023
1 parent 8e46d8c commit 2b97be3
Showing 4 changed files with 38 additions and 43 deletions.
35 changes: 14 additions & 21 deletions paperqa/docs.py
@@ -123,13 +123,14 @@ def add(
             raise ValueError(f"Document {path} already in collection.")
 
         if citation is None:
+            # skip system because it's too hesitant to answer
             cite_chain = make_chain(
-                prompt=citation_prompt, llm=self.summary_llm)
+                prompt=citation_prompt, llm=self.summary_llm, skip_system=True
+            )
             # peek at first chunk
             texts, _ = read_doc(path, "", "", chunk_chars=chunk_chars)
             if len(texts) == 0:
-                raise ValueError(
-                    f"Could not read document {path}. Is it empty?")
+                raise ValueError(f"Could not read document {path}. Is it empty?")
             citation = cite_chain.run(texts[0])
             if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
                 citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
@@ -149,8 +150,7 @@ def add(
                 year = ""
         key = f"{author}{year}"
         key = self.get_unique_key(key)
-        texts, metadata = read_doc(
-            path, citation, key, chunk_chars=chunk_chars)
+        texts, metadata = read_doc(path, citation, key, chunk_chars=chunk_chars)
         # loose check to see if document was loaded
         #
         if len("".join(texts)) < 10 or (
@@ -232,6 +232,7 @@ def doc_previews(self) -> List[Tuple[int, str, str]]:
                 len(doc["texts"]),
                 doc["metadata"][0]["dockey"],
                 doc["metadata"][0]["citation"],
+                doc["hash"],
             )
             for doc in self.docs
         ]
@@ -244,16 +245,14 @@ async def adoc_match(
             return ""
         if self._doc_index is None:
             texts = [doc["metadata"][0]["citation"] for doc in self.docs]
-            metadatas = [{"key": doc["metadata"][0]["dockey"]}
-                         for doc in self.docs]
+            metadatas = [{"key": doc["metadata"][0]["dockey"]} for doc in self.docs]
             self._doc_index = FAISS.from_texts(
                 texts, metadatas=metadatas, embedding=self.embeddings
             )
         docs = self._doc_index.max_marginal_relevance_search(
             query, k=k + len(self._deleted_keys)
         )
-        docs = [doc for doc in docs if doc.metadata["key"]
-                not in self._deleted_keys]
+        docs = [doc for doc in docs if doc.metadata["key"] not in self._deleted_keys]
         chain = make_chain(select_paper_prompt, self.summary_llm)
         papers = [f"{d.metadata['key']}: {d.page_content}" for d in docs]
         result = await chain.arun(
@@ -269,16 +268,14 @@ def doc_match(
             return ""
         if self._doc_index is None:
             texts = [doc["metadata"][0]["citation"] for doc in self.docs]
-            metadatas = [{"key": doc["metadata"][0]["dockey"]}
-                         for doc in self.docs]
+            metadatas = [{"key": doc["metadata"][0]["dockey"]} for doc in self.docs]
             self._doc_index = FAISS.from_texts(
                 texts, metadatas=metadatas, embedding=self.embeddings
             )
         docs = self._doc_index.max_marginal_relevance_search(
             query, k=k + len(self._deleted_keys)
         )
-        docs = [doc for doc in docs if doc.metadata["key"]
-                not in self._deleted_keys]
+        docs = [doc for doc in docs if doc.metadata["key"] not in self._deleted_keys]
         chain = make_chain(select_paper_prompt, self.summary_llm)
         papers = [f"{d.metadata['key']}: {d.page_content}" for d in docs]
         result = chain.run(
@@ -297,8 +294,7 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.__dict__.update(state)
         try:
-            self._faiss_index = FAISS.load_local(
-                self.index_path, self.embeddings)
+            self._faiss_index = FAISS.load_local(self.index_path, self.embeddings)
         except:
             # they use some special exception type, but I don't want to import it
             self._faiss_index = None
@@ -313,11 +309,9 @@ def __setstate__(self, state):
 
     def _build_faiss_index(self):
         if self._faiss_index is None:
-            texts = reduce(lambda x, y: x + y,
-                           [doc["texts"] for doc in self.docs], [])
+            texts = reduce(lambda x, y: x + y, [doc["texts"] for doc in self.docs], [])
             text_embeddings = reduce(
-                lambda x, y: x + y, [doc["text_embeddings"]
-                                     for doc in self.docs], []
+                lambda x, y: x + y, [doc["text_embeddings"] for doc in self.docs], []
             )
             metadatas = reduce(
                 lambda x, y: x + y, [doc["metadata"] for doc in self.docs], []
@@ -386,8 +380,7 @@ async def aget_evidence(
         )
         # ok now filter
         if key_filter is not None:
-            docs = [doc for doc in docs if doc.metadata["dockey"]
-                    in key_filter][:k]
+            docs = [doc for doc in docs if doc.metadata["dockey"] in key_filter][:k]
 
         async def process(doc):
             if doc.metadata["dockey"] in self._deleted_keys:
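
Aside: the citation fallback exercised by the fixed tests is the check in Docs.add above: if the model's citation is too short or hedges, a placeholder is synthesized from the filename and year. A standalone sketch of that logic (the wrapper function is illustrative, not a name from the codebase):

import os
from datetime import datetime

def fallback_citation(citation: str, path: str) -> str:
    # Mirrors the check in Docs.add: reject model output that is too short,
    # says "Unknown", or hedges with "insufficient".
    if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
        return f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
    return citation

print(fallback_citation("insufficient information", "papers/smith2020.pdf"))
# -> "Unknown, smith2020.pdf, <current year>"
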
13 changes: 8 additions & 5 deletions paperqa/qaprompts.py
@@ -73,7 +73,7 @@ def _get_datetime():
 
 citation_prompt = prompts.PromptTemplate(
     input_variables=["text"],
-    template="Provide a citation for the following text in MLA Format. You must answer. Today's date is {date}\n"
+    template="Provide the citation for the following text in MLA Format. Today's date is {date}\n"
     "{text}\n\n"
     "Citation:",
     partial_variables={"date": _get_datetime},
@@ -95,17 +95,20 @@ async def agenerate(
         return self.generate(input_list, run_manager=run_manager)
 
 
-def make_chain(prompt, llm):
+def make_chain(prompt, llm, skip_system=False):
     if type(llm) == ChatOpenAI:
         system_message_prompt = SystemMessage(
             content="Answer in an unbiased, concise, scholarly tone. "
             "You may refuse to answer if there is insufficient information. "
             "If there are ambiguous terms or acronyms, first define them. ",
         )
         human_message_prompt = HumanMessagePromptTemplate(prompt=prompt)
-        prompt = ChatPromptTemplate.from_messages(
-            [system_message_prompt, human_message_prompt]
-        )
+        if skip_system:
+            prompt = ChatPromptTemplate.from_messages([human_message_prompt])
+        else:
+            prompt = ChatPromptTemplate.from_messages(
+                [system_message_prompt, human_message_prompt]
+            )
     return FallbackLLMChain(prompt=prompt, llm=llm)


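For context, a minimal sketch of how the new skip_system flag is meant to be used (the import paths and ChatOpenAI setup are assumptions; make_chain and citation_prompt are the names from the diff above):

from langchain.chat_models import ChatOpenAI
from paperqa.qaprompts import citation_prompt, make_chain

llm = ChatOpenAI(temperature=0.0)

# Default: the scholarly system message is prepended, which made the model
# too hesitant to produce a citation.
qa_chain = make_chain(prompt=citation_prompt, llm=llm)

# skip_system=True drops the system message, so only the human prompt is sent
# and the model answers the citation request directly.
cite_chain = make_chain(prompt=citation_prompt, llm=llm, skip_system=True)
print(cite_chain.run("First page of some PDF..."))
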
2 changes: 1 addition & 1 deletion paperqa/version.py
@@ -1 +1 @@
-__version__ = "2.0.1"
+__version__ = "2.0.2"
31 changes: 15 additions & 16 deletions tests/test_paperqa.py
@@ -141,28 +141,26 @@ def test_docs_pickle():
         # get front page of wikipedia
         r = requests.get("https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day")
         f.write(r.text)
-    llm = OpenAI(temperature=0.0, model_name="text-babbage-001")
+    llm = OpenAI(temperature=0.0, model_name="text-curie-001")
     docs = paperqa.Docs(llm=llm)
     docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000)
     docs_pickle = pickle.dumps(docs)
     docs2 = pickle.loads(docs_pickle)
     docs2.update_llm(llm)
     assert len(docs.docs) == len(docs2.docs)
-    assert (
-        strings_similarity(
-            docs.get_evidence(
-                paperqa.Answer("What date is flag day in Canada?"),
-                k=3,
-                max_sources=1,
-            ).context,
-            docs2.get_evidence(
-                paperqa.Answer("What date is flag day in Canada?"),
-                k=3,
-                max_sources=1,
-            ).context,
-        )
-        > 0.75
-    )
+    context1, context2 = (
+        docs.get_evidence(
+            paperqa.Answer("What date is flag day in Canada?"),
+            k=3,
+            max_sources=1,
+        ).context,
+        docs2.get_evidence(
+            paperqa.Answer("What date is flag day in Canada?"),
+            k=3,
+            max_sources=1,
+        ).context,
+    )
+    assert strings_similarity(context1, context2) > 0.75
     os.remove(doc_path)


@@ -172,7 +170,7 @@ def test_docs_pickle_no_faiss():
         # get front page of wikipedia
         r = requests.get("https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day")
         f.write(r.text)
-    llm = OpenAI(temperature=0.0, model_name="text-babbage-001")
+    llm = OpenAI(temperature=0.0, model_name="text-curie-001")
     docs = paperqa.Docs(llm=llm)
     docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000)
     docs._faiss_index = None
@@ -306,6 +304,7 @@ def test_citation():
         f.write(r.text)
     docs = paperqa.Docs()
     docs.add(doc_path)
+    print(docs.docs[0]["metadata"][0]["citation"])
     assert (
         list(docs.docs)[0]["metadata"][0]["key"] == "Wikipedia2023"
         or list(docs.docs)[0]["metadata"][0]["key"] == "Frederick2023"
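
Note: the reworked pickle test binds both contexts to names before comparing, so a failing run shows which side diverged. strings_similarity is a helper defined in the test suite; a plausible stand-in (an assumption, not the repository's implementation) is a token-overlap score:

def strings_similarity(s1: str, s2: str) -> float:
    # Jaccard overlap of whitespace-separated tokens: 1.0 for identical
    # token sets, 0.0 for disjoint ones. Hypothetical stand-in only.
    t1, t2 = set(s1.lower().split()), set(s2.lower().split())
    if not t1 or not t2:
        return 0.0
    return len(t1 & t2) / len(t1 | t2)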
