From 5409914bfd344ab0d8cac8763f40e77db6a099b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20Clavi=C3=A9?= Date: Sun, 11 Feb 2024 21:53:31 +0100 Subject: [PATCH] Fix: dynamically increase query params for higher `k` values (#131) * fix: return enough results if k > ncells*32 * fix: increase both ndocs and ncells to match k * chore: prepare release * linting * chore: saner ncells for larger datasets * linting --- pyproject.toml | 2 +- ragatouille/__init__.py | 2 +- ragatouille/models/colbert.py | 28 ++++++++++++++++++++++++---- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index da47dfd..5f1b10e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "RAGatouille" -version = "0.0.6c1" +version = "0.0.6c2" description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts." authors = ["Benjamin Clavie "] license = "Apache-2.0" diff --git a/ragatouille/__init__.py b/ragatouille/__init__.py index b2464c8..aba0f21 100644 --- a/ragatouille/__init__.py +++ b/ragatouille/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.6c1" +__version__ = "0.0.6c2" from .RAGPretrainedModel import RAGPretrainedModel from .RAGTrainer import RAGTrainer diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index d286851..3f2cbd5 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -422,14 +422,14 @@ def _load_searcher( ) if not force_fast: + self.searcher.configure(ndocs=1024) + self.searcher.configure(ncells=16) if len(self.searcher.collection) < 10000: - self.searcher.configure(ncells=4) + self.searcher.configure(ncells=8) self.searcher.configure(centroid_score_threshold=0.4) - self.searcher.configure(ndocs=512) elif len(self.searcher.collection) < 100000: - self.searcher.configure(ncells=2) + self.searcher.configure(ncells=4) self.searcher.configure(centroid_score_threshold=0.45) - self.searcher.configure(ndocs=1024) # Otherwise, use defaults for k else: # Use fast settingss @@ -459,6 +459,22 @@ def search( for doc_id in doc_ids: pids.extend(self.docid_pid_map[doc_id]) + base_ncells = self.searcher.config.ncells + base_ndocs = self.searcher.config.ndocs + + if k > len(self.searcher.collection): + print( + "WARNING: k value is larger than the number of documents in the index!", + f"Lowering k to {len(self.searcher.collection)}...", + ) + k = len(self.searcher.collection) + + # For smaller collections, we need a higher ncells value to ensure we return enough results + if k > (32 * self.searcher.config.ncells): + self.searcher.configure(ncells=min((k // 32 + 2), base_ncells)) + + self.searcher.configure(ndocs=max(k * 4, base_ndocs)) + if isinstance(query, str): results = [self._search(query, k, pids)] else: @@ -487,6 +503,10 @@ def search( to_return.append(result_for_query) + # Restore original ncells&ndocs if it had to be changed for large k values + self.searcher.configure(ncells=base_ncells) + self.searcher.configure(ndocs=base_ndocs) + if len(to_return) == 1: return to_return[0] return to_return