Skip to content

Commit

Permalink
fixed test problems
Browse files — browse the repository at this point in the history
  • Loading branch information
whitead committed Feb 25, 2023
1 parent c7072e1 commit c65ef5c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
13 changes: 8 additions & 5 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ def __init__(
index_path = Path.home() / ".paperqa" / name
self.index_path = index_path

def update_llm(self, llm: LLM, summary_llm: LLM) -> None:
def update_llm(self, llm: LLM, summary_llm: Optional[LLM] = None) -> None:
"""Update the LLM for answering questions."""
self.llm = llm
if summary_llm is None:
summary_llm = llm
self.summary_llm = summary_llm
self.summary_chain = LLMChain(prompt=summary_prompt, llm=summary_llm)
self.qa_chain = LLMChain(prompt=qa_prompt, llm=llm)
self.edit_chain = LLMChain(prompt=edit_prompt, llm=llm)
Expand Down Expand Up @@ -207,15 +211,14 @@ def generate_search_query(self, query: str) -> List[str]:

search_query = self.search_chain.run(question=query)
queries = [s for s in search_query.split("\n") if len(s) > 3]
if '"' in queries[0]:
# often they're numbered/encased in quotes
queries = [q[q.find('"') + 1 : q.rfind('"')] for q in queries]
# remove 2., 3. from queries
queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
return queries

def query(
self,
query: str,
k: int = 5,
k: int = 10,
max_sources: int = 5,
length_prompt: str = "about 100 words",
progress: Callable[[str], str] = None,
Expand Down
7 changes: 4 additions & 3 deletions paperqa/qaprompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
search_prompt = prompts.PromptTemplate(
input_variables=["question"],
template="We want to answer the following question: {question} \n"
"Provide a list of three targeted Google scholar searches (one search per line) "
"that will find papers that help answer the question. The current year is 2023."
"Search terms:\n",
"Provide three different targeted keyword searches (one search per line) "
"that will find papers that help answer the question. Do not use boolean operators. "
"The current year is 2023.\n\n"
"1.",
)
7 changes: 5 additions & 2 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def test_docs_pickle():
# get front page of wikipedia
r = requests.get("https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day")
f.write(r.text)
docs = paperqa.Docs(llm=OpenAI(temperature=0.0, model_name="text-babbage-001"))
llm = llm = OpenAI(temperature=0.0, model_name="text-babbage-001")
docs = paperqa.Docs(llm=llm)
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
docs_pickle = pickle.dumps(docs)
docs2 = pickle.loads(docs_pickle)
docs2.update_llm(llm)
assert len(docs.docs) == len(docs2.docs)
assert (
strings_similarity(
Expand Down Expand Up @@ -151,6 +153,7 @@ def test_prompt_length():
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
docs.query("What is the name of the politician?", length_prompt="25 words")


def test_doc_preview():
doc_path = "example.txt"
with open(doc_path, "w", encoding="utf-8") as f:
Expand All @@ -159,4 +162,4 @@ def test_doc_preview():
f.write(r.text)
docs = paperqa.Docs(llm=OpenAI(temperature=0.0, model_name="text-ada-001"))
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
assert len(docs.doc_previews) == 1
assert len(docs.doc_previews) == 1

0 comments on commit c65ef5c

Please sign in to comment.