Made marginal search optional
whitead committed Feb 26, 2023
1 parent 06de777 commit 4eeed82
Showing 2 changed files with 47 additions and 15 deletions.
57 changes: 46 additions & 11 deletions paperqa/docs.py
@@ -12,6 +12,10 @@
 from langchain.llms.base import LLM
 from langchain.chains import LLMChain
 from langchain.callbacks import get_openai_callback
+from langchain.cache import InMemoryCache
+import langchain
+
+langchain.llm_cache = InMemoryCache()
 
 
 @dataclass
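
Editor's note: the four added lines above wire up langchain's process-wide LLM cache, so identical prompts (paper-qa re-summarizes the same chunks across queries) are answered from memory instead of re-calling the API. A minimal sketch of the effect; only the cache wiring comes from this commit, and the OpenAI usage is hypothetical and left commented out:

import langchain
from langchain.cache import InMemoryCache

langchain.llm_cache = InMemoryCache()  # process-wide, as in this commit

# Hypothetical usage: any langchain LLM now consults the cache first.
# from langchain.llms import OpenAI
# llm = OpenAI(temperature=0)
# llm("Summarize this chunk ...")  # first call hits the API
# llm("Summarize this chunk ...")  # identical call is served from memory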
@@ -200,14 +204,20 @@ def get_evidence(
         answer: Answer,
         k: int = 3,
         max_sources: int = 5,
+        marginal_relevance: bool = True,
     ) -> str:
         if self._faiss_index is None:
             self._build_faiss_index()
 
         # want to work through indices but less k
-        docs = self._faiss_index.max_marginal_relevance_search(
-            answer.question, k=k, fetch_k=5 * k
-        )
+        if marginal_relevance:
+            docs = self._faiss_index.max_marginal_relevance_search(
+                answer.question, k=k, fetch_k=5 * k
+            )
+        else:
+            docs = self._faiss_index.similarity_search(
+                answer.question, k=k, fetch_k=5 * k
+            )
         for doc in docs:
             c = (
                 doc.metadata["key"],
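
Editor's note: the two search modes being toggled here differ in one respect. similarity_search returns the k chunks nearest the query embedding, even if they are near-duplicates; max_marginal_relevance_search re-ranks a larger candidate pool (fetch_k) to balance relevance to the query against redundancy among the picks. An illustrative implementation of the MMR rule, not FAISS's code, assuming unit-normalized embeddings and a tunable lambda_mult:

import numpy as np

def mmr_select(query_vec, doc_vecs, k=3, lambda_mult=0.5):
    # Greedy maximal marginal relevance over unit-normalized vectors.
    sim_to_query = doc_vecs @ query_vec
    selected = [int(np.argmax(sim_to_query))]
    while len(selected) < min(k, len(doc_vecs)):
        rest = [i for i in range(len(doc_vecs)) if i not in selected]
        # Penalize candidates that resemble something already chosen.
        redundancy = np.array(
            [max(doc_vecs[i] @ doc_vecs[j] for j in selected) for i in rest]
        )
        scores = lambda_mult * sim_to_query[rest] - (1 - lambda_mult) * redundancy
        selected.append(rest[int(np.argmax(scores))])
    return selected

Passing marginal_relevance=False skips this diversification and returns the raw nearest neighbors, redundant or not.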
@@ -251,9 +261,14 @@ def query_gen(
         k: int = 10,
         max_sources: int = 5,
         length_prompt: str = "about 100 words",
+        marginal_relevance: bool = True,
     ):
         yield from self._query(
-            query, k=k, max_sources=max_sources, length_prompt=length_prompt
+            query,
+            k=k,
+            max_sources=max_sources,
+            length_prompt=length_prompt,
+            marginal_relevance=marginal_relevance,
         )
 
     def query(
@@ -262,20 +277,37 @@ def query(
         k: int = 10,
         max_sources: int = 5,
         length_prompt: str = "about 100 words",
+        marginal_relevance: bool = True,
     ):
         for answer in self._query(
-            query, k=k, max_sources=max_sources, length_prompt=length_prompt
+            query,
+            k=k,
+            max_sources=max_sources,
+            length_prompt=length_prompt,
+            marginal_relevance=marginal_relevance,
         ):
             pass
         return answer
 
-    def _query(self, query: str, k: int, max_sources: int, length_prompt: str):
+    def _query(
+        self,
+        query: str,
+        k: int,
+        max_sources: int,
+        length_prompt: str,
+        marginal_relevance: bool,
+    ):
         if k < max_sources:
             raise ValueError("k should be greater than max_sources")
         tokens = 0
         answer = Answer(query)
         with get_openai_callback() as cb:
-            for answer in self.get_evidence(answer, k=k, max_sources=max_sources):
+            for answer in self.get_evidence(
+                answer,
+                k=k,
+                max_sources=max_sources,
+                marginal_relevance=marginal_relevance,
+            ):
                 yield answer
             tokens += cb.total_tokens
         context_str, citations = answer.context, answer.contexts
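
Editor's note: _query is a generator that repeatedly yields the same Answer object as evidence accumulates. query_gen exposes that stream directly, while query simply drains it and returns the final state. A self-contained sketch of this drain-the-generator pattern; the names are illustrative, not paper-qa's:

def _work(steps: int = 3):
    progress = {"done": 0}
    for i in range(steps):
        progress["done"] = i + 1  # mutate shared state
        yield progress            # stream intermediate snapshots

def work_gen(steps: int = 3):
    yield from _work(steps)       # streaming interface, like query_gen

def work(steps: int = 3):
    for progress in _work(steps):  # blocking interface, like query
        pass
    return progress                # return the last yielded state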
@@ -290,11 +322,14 @@ def _query(self, query: str, k: int, max_sources: int, length_prompt: str):
             answer_text = self.qa_chain.run(
                 question=query, context_str=context_str, length=length_prompt
             )[1:]
-            if maybe_is_truncated(answer_text):
-                answer_text = self.edit_chain.run(
-                    question=query, answer=answer_text
-                )
+            # if maybe_is_truncated(answer_text):
+            #     answer_text = self.edit_chain.run(
+            #         question=query, answer=answer_text
+            #     )
             tokens += cb.total_tokens
+        # it still happens lol
+        if "(Foo2012)" in answer_text:
+            answer_text = answer_text.replace("(Foo2012)", "")
         for key, citation, summary, text in citations:
             # do check for whole key (so we don't catch Callahan2019a with Callahan2019)
             skey = key.split(" ")[0]
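
Editor's note: the truncation-repair pass is disabled and replaced by a literal scrub of "(Foo2012)", the example citation marker from the prompt, which the model still occasionally copies into answers (hence the "it still happens lol" comment). A slightly more defensive variant one could write, an assumption on my part rather than what the commit does, tolerates whitespace inside the parentheses:

import re

def scrub_placeholder(answer_text: str) -> str:
    # Remove the example citation leaked from the prompt, e.g. "( Foo2012 )".
    return re.sub(r"\(\s*Foo2012\s*\)", "", answer_text)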
5 changes: 1 addition & 4 deletions paperqa/qaprompts.py
@@ -24,11 +24,8 @@
     "For each sentence in your answer, indicate which sources most support it "
     "via valid citation markers at the end of sentences, like (Foo2012). "
     "Answer in an unbiased, balanced, and scientific tone. "
-    "Use Markdown for formatting code or text. "
-    # "write a complete unbiased answer prefixed by \"Answer:\""
-    "\n--------------------\n"
+    "Use Markdown for formatting code or text.\n\n"
     "{context_str}\n"
-    "----------------------\n"
     "Question: {question}\n"
     "Answer: ",
 )
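
Editor's note: taken together, the two files give callers one new knob. A hedged end-to-end sketch; the file name, citation string, and question are hypothetical, and the Docs/Answer surface shown is my reading of the API at this commit, not guaranteed:

from paperqa import Docs

docs = Docs()
docs.add("my_paper.pdf", citation="Smith2021, Example Paper, 2021")  # hypothetical inputs

# Default behavior is unchanged: MMR-diversified evidence retrieval.
answer = docs.query("What method does the paper propose?")

# New in this commit: plain top-k similarity search instead of MMR.
answer = docs.query("What method does the paper propose?", marginal_relevance=False)
print(answer.formatted_answer)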