Skip to content

Commit

Permalink
More improvements for reducing memory usage (#160)
Browse files Browse the repository at this point in the history
* Busted implementation

* Fixed key bug

* Fixed old unit tests too
  • Loading branch information
whitead authored Jul 12, 2023
1 parent a922ea4 commit b4124c8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 17 deletions.
29 changes: 22 additions & 7 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Docs(BaseModel, arbitrary_types_allowed=True, smart_union=True):
prompts: PromptCollection = PromptCollection()
memory: bool = False
memory_model: Optional[BaseChatMemory] = None
jit_texts_index: bool = False

# TODO: Not sure how to get this to work
# while also passing mypy checks
Expand Down Expand Up @@ -83,6 +84,11 @@ def check_memory_model(cls, v, values):
return values["memory_model"]
return None

def clear_docs(self):
    """Reset this collection to an empty state.

    Drops all stored text chunks, document records, and document
    names so the instance can be reused for a fresh corpus.
    """
    # Rebind each container to a fresh empty one (rather than mutating
    # in place) so anything still holding the old collections is
    # unaffected by the reset.
    self.docnames = set()
    self.docs = {}
    self.texts = []

def update_llm(
self,
llm: Union[BaseLanguageModel, str],
Expand Down Expand Up @@ -280,9 +286,9 @@ async def adoc_match(
self, query: str, k: int = 25, get_callbacks: CallbackFactory = lambda x: None
) -> Set[DocKey]:
"""Return a list of dockeys that match the query."""
if len(self.docs) == 0:
return set()
if self.doc_index is None:
if len(self.docs) == 0:
return set()
texts = [doc.citation for doc in self.docs.values()]
metadatas = [d.dict() for d in self.docs.values()]
self.doc_index = FAISS.from_texts(
Expand Down Expand Up @@ -329,11 +335,19 @@ def __setstate__(self, state):
self.texts_index = None
self.doc_index = None

def _build_texts_index(self):
def _build_texts_index(self, keys: Optional[Set[DocKey]] = None):
if keys is not None and self.jit_texts_index:
del self.texts_index
self.texts_index = None
if self.texts_index is None:
raw_texts = [t.text for t in self.texts]
text_embeddings = [t.embeddings for t in self.texts]
metadatas = [t.dict(exclude={"embeddings", "text"}) for t in self.texts]
texts = self.texts
if keys is not None:
texts = [t for t in texts if t.doc.dockey in keys]
if len(texts) == 0:
return
raw_texts = [t.text for t in texts]
text_embeddings = [t.embeddings for t in texts]
metadatas = [t.dict(exclude={"embeddings", "text"}) for t in texts]
self.texts_index = FAISS.from_embeddings(
# wow adding list to the zip was tricky
text_embeddings=list(zip(raw_texts, text_embeddings)),
Expand Down Expand Up @@ -384,8 +398,9 @@ async def aget_evidence(
) -> Answer:
if len(self.docs) == 0 and self.doc_index is None:
return answer
self._build_texts_index(keys=answer.dockey_filter)
if self.texts_index is None:
self._build_texts_index()
return answer
self.texts_index = cast(VectorStore, self.texts_index)
_k = k
if answer.dockey_filter is not None:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.1"
__version__ = "3.3.0"
40 changes: 31 additions & 9 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_update_llm():


def test_evidence():
doc_path = "example.txt"
doc_path = "example.html"
with open(doc_path, "w", encoding="utf-8") as f:
# get wiki page about politician
r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)")
Expand All @@ -202,6 +202,7 @@ def test_evidence():
evidence = docs.get_evidence(
Answer(question="For which state was Bates a governor?"), k=1, max_sources=1
)
print(evidence.context)
assert "Missouri" in evidence.context
os.remove(doc_path)

Expand Down Expand Up @@ -256,7 +257,7 @@ def test_adoc_match(self):


def test_docs_pickle():
doc_path = "example.txt"
doc_path = "example.html"
with open(doc_path, "w", encoding="utf-8") as f:
# get front page of wikipedia
r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day")
Expand Down Expand Up @@ -295,7 +296,7 @@ def test_docs_pickle():


def test_docs_pickle_no_faiss():
doc_path = "example.txt"
doc_path = "example.html"
with open(doc_path, "w", encoding="utf-8") as f:
# get front page of wikipedia
r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day")
Expand Down Expand Up @@ -592,13 +593,13 @@ def test_custom_prompts():

docs = Docs(prompts=PromptCollection(qa=my_qaprompt))

doc_path = "example.txt"
doc_path = "example.html"
with open(doc_path, "w", encoding="utf-8") as f:
# get wiki page about politician
r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)")
f.write(r.text)
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
answer = docs.query("What country is Bates from?")
answer = docs.query("What country is Frederick Bates from?")
print(answer.answer)
assert "United States" in answer.answer

Expand Down Expand Up @@ -643,26 +644,26 @@ def test_post_prompt():


def test_memory():
docs = Docs(memory=True, k=3, max_sources=1, llm="gpt-3.5-turbo", key_filter=False)
# Not sure why, but gpt-3.5 cannot do this anymore.
docs = Docs(memory=True, k=3, max_sources=1, llm="gpt-4", key_filter=False)
docs.add_url(
"https://en.wikipedia.org/wiki/Red_Army",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
)
answer1 = docs.query("When did the Soviet Union and Japan agree to a cease-fire?")
print(answer1.answer)
assert answer1.memory is not None
assert "1939" in answer1.answer
assert "Answer" in docs.memory_model.load_memory_variables({})["memory"]
answer2 = docs.query("When was the conflict resolved?")
answer2 = docs.query("When was it resolved?")
assert "1941" in answer2.answer or "1945" in answer2.answer
assert answer2.memory is not None
assert "Answer" in docs.memory_model.load_memory_variables({})["memory"]
print(answer2.answer)

docs.clear_memory()

answer3 = docs.query("When was the conflict resolved?")
answer3 = docs.query("When was it resolved?")
assert answer3.memory is not None
assert (
"I cannot answer" in answer3.answer
Expand Down Expand Up @@ -721,3 +722,24 @@ def test_external_doc_index():
assert len(docs2.docs) == 0
evidence = docs2.query("What is the date of flag day?", key_filter=True)
assert "February 15" in evidence.context


def test_external_texts_index():
    """Query a Docs collection built with a just-in-time texts index.

    With ``jit_texts_index=True`` the texts index is (re)built at query
    time, so the same question must keep producing the correct answer
    both with and without key filtering, and after more documents are
    added to the collection.
    """
    jit_docs = Docs(jit_texts_index=True)
    jit_docs.add_url(
        "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day",
        citation="Flag Day of Canada, WikiMedia Foundation, 2023, Accessed now",
    )
    answer = jit_docs.query(query="What is the date of flag day?", key_filter=True)
    assert "February 15" in answer.answer

    # Adding an unrelated document must not disturb the flag-day answer.
    jit_docs.add_url(
        "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
        citation="Fredrick Bates, WikiMedia Foundation, 2023, Accessed now",
    )

    answer = jit_docs.query(query="What is the date of flag day?", key_filter=False)
    assert "February 15" in answer.answer

    answer = jit_docs.query(query="What is the date of flag day?", key_filter=True)
    assert "February 15" in answer.answer

0 comments on commit b4124c8

Please sign in to comment.