Skip to content

Commit

Permalink
Added max concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed May 28, 2023
1 parent 1dc6ad0 commit bda7bef
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
11 changes: 9 additions & 2 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from .readers import read_doc
from .types import Answer, Context
from .utils import maybe_is_text, md5sum
from .utils import maybe_is_text, md5sum, gather_with_concurrency

os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
langchain.llm_cache = SQLiteCache(CACHE_PATH)
Expand All @@ -47,6 +47,7 @@ def __init__(
name: str = "default",
index_path: Optional[Path] = None,
embeddings: Optional[Embeddings] = None,
max_concurrent: int = 5,
) -> None:
"""Initialize the collection of documents.
Expand All @@ -57,6 +58,7 @@ def __init__(
name: The name of the collection.
index_path: The path to the index file IF pickled. If None, defaults to using name in $HOME/.paperqa/name
embeddings: The embeddings to use for indexing documents. Default - OpenAI embeddings
max_concurrent: Number of concurrent LLM model calls to make
"""
self.docs = []
self.chunk_size_limit = chunk_size_limit
Expand All @@ -71,6 +73,7 @@ def __init__(
if embeddings is None:
embeddings = OpenAIEmbeddings()
self.embeddings = embeddings
self.max_concurrent = max_concurrent
self._deleted_keys = set()

def update_llm(
Expand Down Expand Up @@ -295,6 +298,8 @@ def __setstate__(self, state):
# must be a better way to have backwards compatibility
if not hasattr(self, "_deleted_keys"):
self._deleted_keys = set()
if not hasattr(self, "max_concurrent"):
self.max_concurrent = 5
self.update_llm(None, None)

def _build_faiss_index(self):
Expand Down Expand Up @@ -396,7 +401,9 @@ async def process(doc):
return c, callbacks[0]
return None, None

results = await asyncio.gather(*[process(doc) for doc in docs])
results = await gather_with_concurrency(
self.max_concurrent, *[process(doc) for doc in docs]
)
# filter out failures
results = [r for r in results if r[0] is not None]
answer.tokens += sum([cb.total_tokens for _, cb in results])
Expand Down
12 changes: 12 additions & 0 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import string
import asyncio

import pypdf

Expand Down Expand Up @@ -68,3 +69,14 @@ def md5sum(file_path: StrPath) -> str:

with open(file_path, "rb") as f:
return hashlib.md5(f.read()).hexdigest()


async def gather_with_concurrency(n, *coros):
    """Run *coros* concurrently while allowing at most ``n`` in flight.

    Behaves exactly like :func:`asyncio.gather` (results are returned in
    the order the coroutines were passed), except that a semaphore caps
    how many of the supplied coroutines may be executing at any moment.

    Adapted from https://stackoverflow.com/a/61478547/2392535
    """
    limiter = asyncio.Semaphore(n)

    async def throttled(coro):
        # Acquire a slot before running; the slot is released on exit
        # even if the coroutine raises.
        async with limiter:
            return await coro

    wrapped = [throttled(coro) for coro in coros]
    return await asyncio.gather(*wrapped)
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.11.0"
__version__ = "1.12.0"

0 comments on commit bda7bef

Please sign in to comment.