Skip to content

Commit

Permalink
Fix: dynamically increase query params for higher k values (#131)
Browse files Browse the repository at this point in the history
* fix: return enough results if k > ncells*32

* fix: increase both ndocs and ncells to match k

* chore: prepare release

* linting

* chore: saner ncells for larger datasets

* linting
  • Loading branch information
bclavie authored Feb 11, 2024
1 parent 9ef207d commit 5409914
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
version = "0.0.6c1"
version = "0.0.6c2"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <ben@clavie.eu>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.6c1"
__version__ = "0.0.6c2"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

Expand Down
28 changes: 24 additions & 4 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,14 @@ def _load_searcher(
)

if not force_fast:
self.searcher.configure(ndocs=1024)
self.searcher.configure(ncells=16)
if len(self.searcher.collection) < 10000:
self.searcher.configure(ncells=4)
self.searcher.configure(ncells=8)
self.searcher.configure(centroid_score_threshold=0.4)
self.searcher.configure(ndocs=512)
elif len(self.searcher.collection) < 100000:
self.searcher.configure(ncells=2)
self.searcher.configure(ncells=4)
self.searcher.configure(centroid_score_threshold=0.45)
self.searcher.configure(ndocs=1024)
# Otherwise, use defaults for k
else:
# Use fast settingss
Expand Down Expand Up @@ -459,6 +459,22 @@ def search(
for doc_id in doc_ids:
pids.extend(self.docid_pid_map[doc_id])

base_ncells = self.searcher.config.ncells
base_ndocs = self.searcher.config.ndocs

if k > len(self.searcher.collection):
print(
"WARNING: k value is larger than the number of documents in the index!",
f"Lowering k to {len(self.searcher.collection)}...",
)
k = len(self.searcher.collection)

# For smaller collections, we need a higher ncells value to ensure we return enough results
if k > (32 * self.searcher.config.ncells):
self.searcher.configure(ncells=min((k // 32 + 2), base_ncells))

self.searcher.configure(ndocs=max(k * 4, base_ndocs))

if isinstance(query, str):
results = [self._search(query, k, pids)]
else:
Expand Down Expand Up @@ -487,6 +503,10 @@ def search(

to_return.append(result_for_query)

# Restore original ncells&ndocs if it had to be changed for large k values
self.searcher.configure(ncells=base_ncells)
self.searcher.configure(ndocs=base_ndocs)

if len(to_return) == 1:
return to_return[0]
return to_return
Expand Down

0 comments on commit 5409914

Please sign in to comment.