Commit
Add BM25 model and BLIP2 Captioned COCO Dataset
BM25 and new dataset
Showing 12 changed files with 1,982 additions and 3 deletions.
Large diffs are not rendered by default.
Package `__init__.py` exporting the new COCO caption datasets:

@@ -1,3 +1,7 @@
-from .coco import COCODataset
+from .coco import COCODataset, COCODatasetBLIP2Captions, COCODatasetVLRMCaptions
 
-__all__ = ["COCODataset"]
+__all__ = [
+    "COCODataset",
+    "COCODatasetBLIP2Captions",
+    "COCODatasetVLRMCaptions",
+]
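With the expanded exports, the new caption variants are importable alongside the original dataset. A minimal sketch, assuming this `__init__.py` belongs to an `xretrieval.datasets` package (the file path is not shown on this page):

# Hypothetical package path; only the relative ".coco" import is visible in the diff.
from xretrieval.datasets import (
    COCODataset,
    COCODatasetBLIP2Captions,
    COCODatasetVLRMCaptions,
)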
New file registering a BM25 model for text retrieval, built on the bm25s library:

@@ -0,0 +1,38 @@
+import bm25s
+import numpy as np
+import Stemmer
+
+from xretrieval.models_registry import ModelRegistry
+
+
+@ModelRegistry.register("xhluca/bm25s", model_input="text")
+class BM25sModel:
+    def __init__(self, model_id: str = "xhluca/bm25s"):
+        self.model_id = model_id
+        self.model = self.load_model()
+
+        self.corpus_tokens = None
+        self.stemmer = Stemmer.Stemmer("english")
+
+    def load_model(self):
+        return bm25s.BM25()
+
+    def tokenize_text(self, text: list[str]):
+        corpus_tokens = bm25s.tokenize(text, stopwords="en", stemmer=self.stemmer)
+        self.model.index(corpus_tokens)
+        self.corpus_tokens = corpus_tokens
+
+    def retrieve(self, queries: list[str], top_k: int) -> np.ndarray:
+        queries_tokens = bm25s.tokenize(queries, stopwords="en", stemmer=self.stemmer)
+        results = self.model.retrieve(
+            queries_tokens, k=top_k + 1
+        )  # +1 for self-matches
+
+        retrieved_ids = []
+        # Filter self-matches for each query
+        for idx, docs in enumerate(results.documents):
+            filtered_docs = [doc for doc in docs if doc != idx][:top_k]
+            retrieved_ids.append(filtered_docs)
+        retrieved_ids = np.array(retrieved_ids)
+
+        return retrieved_ids
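For reference, a minimal usage sketch of the class above (not part of the commit). It assumes bm25s and PyStemmer are installed, that BM25sModel from the new file is in scope, and, matching the self-match filtering in retrieve, that the queries are the same captions that were indexed; the example texts are made up:

# Toy corpus; in xretrieval this would be a dataset's captions.
captions = [
    "a dog plays with a ball in the park",
    "a cat sleeps on a red couch",
    "two dogs run along the beach",
    "a kitten curled up on a sofa",
]

model = BM25sModel()           # registered as "xhluca/bm25s" in the ModelRegistry
model.tokenize_text(captions)  # tokenize, stem, drop stopwords, and index the corpus
ids = model.retrieve(captions, top_k=2)

# ids has shape (len(captions), 2): for each caption, the indices of its two
# closest neighbours, with the caption's own index filtered out as a self-match.
print(ids)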