Skip to content

Commit

Permalink
Add BM25 model and BLIP2 Captioned COCO Dataset
Browse files Browse the repository at this point in the history
Bm25 and new dataset
  • Loading branch information
dnth authored Dec 3, 2024
2 parents 0482b74 + 607a1ed commit 93b0722
Show file tree
Hide file tree
Showing 12 changed files with 1,982 additions and 3 deletions.
354 changes: 354 additions & 0 deletions nbs/bm25_blip2-captions.ipynb

Large diffs are not rendered by default.

354 changes: 354 additions & 0 deletions nbs/bm25_coco-captions.ipynb

Large diffs are not rendered by default.

354 changes: 354 additions & 0 deletions nbs/bm25_vlrm-captions.ipynb

Large diffs are not rendered by default.

268 changes: 268 additions & 0 deletions nbs/cocoblip2dataset.ipynb

Large diffs are not rendered by default.

496 changes: 496 additions & 0 deletions nbs/runtime_comparison.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dependencies = [
"sentence-transformers>=3.3.0",
"timm>=1.0.0",
"accelerate>=1.1.0",
"bm25s>=0.2.5",
"pystemmer>=2.2.0.3",
]

[build-system]
Expand Down
4 changes: 3 additions & 1 deletion xretrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
load_dataset,
load_model,
run_benchmark,
run_benchmark_bm25,
visualize_ground_truth,
visualize_retrieval,
)
from .datasets import COCODataset
from .datasets import COCODataset, COCODatasetBLIP2Captions, COCODatasetVLRMCaptions
from .datasets_registry import DatasetRegistry
from .models import SentenceTransformerModel
from .models_registry import ModelRegistry
Expand All @@ -22,6 +23,7 @@
"list_models",
"load_model",
"run_benchmark",
"run_benchmark_bm25",
"visualize_retrieval",
"visualize_ground_truth",
]
83 changes: 83 additions & 0 deletions xretrieval/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rich.table import Table

from .datasets_registry import DatasetRegistry
from .models.bm25 import BM25sModel
from .models_registry import ModelRegistry


Expand Down Expand Up @@ -71,6 +72,87 @@ def load_model(model_id: str):
return model_class(model_id=model_id)


def run_benchmark_bm25(dataset: str, top_k: int = 10):
logger.info("Running BM25 retrieval benchmark")
bm25_model = BM25sModel()
dataset = load_dataset(dataset)

logger.info("Tokenizing corpus")
corpus = dataset["caption"].tolist()
bm25_model.tokenize_text(corpus)

# Get labels for evaluation
image_ids = dataset.image_id.tolist()
image_ids = np.array(image_ids)
labels = dataset.loc[(dataset.image_id.isin(image_ids))].name.to_numpy()

logger.info("Performing retrieval")
retrieved_ids = bm25_model.retrieve(corpus, top_k=top_k)

logger.info("Calculating metrics")
matches = np.expand_dims(labels, axis=1) == labels[retrieved_ids]
matches = torch.tensor(np.array(matches), dtype=torch.float16)
targets = torch.ones(matches.shape)
indexes = (
torch.arange(matches.shape[0]).view(-1, 1)
* torch.ones(1, matches.shape[1]).long()
)

metrics = [
torchmetrics.retrieval.RetrievalMRR(),
torchmetrics.retrieval.RetrievalNormalizedDCG(),
torchmetrics.retrieval.RetrievalPrecision(),
torchmetrics.retrieval.RetrievalRecall(),
torchmetrics.retrieval.RetrievalHitRate(),
torchmetrics.retrieval.RetrievalMAP(),
]
eval_metrics_results = {}

for metr in metrics:
score = round(metr(targets, matches, indexes).item(), 4)
metr_name = metr.__class__.__name__.replace("Retrieval", "")
eval_metrics_results[metr_name] = score

table = Table(title=f"Retrieval Metrics @ k={top_k}")
table.add_column("Metric", style="cyan")
table.add_column("Score", style="magenta")

for metric_name, score in eval_metrics_results.items():
table.add_row(metric_name, f"{score:.4f}")

console = Console()
console.print(table)

# Create results DataFrame for visualization
results_data = []
for idx, retrieved in enumerate(retrieved_ids):
query_name = dataset.iloc[idx]["name"]
ground_truth_matches = dataset[
(dataset["name"] == query_name)
& (dataset["image_id"] != dataset.iloc[idx]["image_id"])
]

query_row = {
"query_id": dataset.iloc[idx]["image_id"],
"query_path": dataset.iloc[idx]["image_path"],
"query_caption": dataset.iloc[idx]["caption"],
"query_name": dataset.iloc[idx]["name"],
"retrieved_ids": [dataset.iloc[i]["image_id"] for i in retrieved],
"retrieved_paths": [dataset.iloc[i]["image_path"] for i in retrieved],
"retrieved_captions": [dataset.iloc[i]["caption"] for i in retrieved],
"retrieved_names": [dataset.iloc[i]["name"] for i in retrieved],
"is_correct": [labels[i] == labels[idx] for i in retrieved],
"ground_truth_ids": ground_truth_matches["image_id"].tolist(),
"ground_truth_paths": ground_truth_matches["image_path"].tolist(),
"ground_truth_captions": ground_truth_matches["caption"].tolist(),
}
results_data.append(query_row)

results_df = pd.DataFrame(results_data)

return eval_metrics_results, results_df


def run_benchmark(
dataset: str | pd.DataFrame,
model_id: str,
Expand All @@ -87,6 +169,7 @@ def run_benchmark(
top_k: Number of top results to retrieve (will retrieve top_k + 1 to account for self-matches)
"""
dataset = load_dataset(dataset)

# TODO: Dataset should contain columns ['image_id', 'file_name', 'image_path', 'caption', 'name']
model = load_model(model_id)
model_info = ModelRegistry.get_model_info(model_id)
Expand Down
8 changes: 6 additions & 2 deletions xretrieval/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .coco import COCODataset
from .coco import COCODataset, COCODatasetBLIP2Captions, COCODatasetVLRMCaptions

__all__ = ["COCODataset"]
__all__ = [
"COCODataset",
"COCODatasetBLIP2Captions",
"COCODatasetVLRMCaptions",
]
22 changes: 22 additions & 0 deletions xretrieval/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,25 @@ def get_dataset(self) -> pd.DataFrame:
)

return df


@DatasetRegistry.register(
"coco-val-2017-blip2-captions",
"The COCO Validation Set with 5k images and BLIP2 captions.",
)
class COCODatasetBLIP2Captions(COCODataset):
def load_annotations(self) -> pd.DataFrame:
url = "https://github.com/dnth/x.retrieval/releases/download/v0.1.1/blip2_captioned_coco_val_2017.parquet"
df = pd.read_parquet(url)
return df


@DatasetRegistry.register(
"coco-val-2017-vlrm-captions",
"The COCO Validation Set with 5k images and VLRM captions.",
)
class COCODatasetVLRMCaptions(COCODataset):
def load_annotations(self) -> pd.DataFrame:
url = "https://github.com/dnth/x.retrieval/releases/download/v0.1.1/vlrm_captioned_coco_val_2017.parquet"
df = pd.read_parquet(url)
return df
2 changes: 2 additions & 0 deletions xretrieval/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .blip2 import BLIP2ImageModel, BLIP2Model, BLIP2TextModel
from .bm25 import BM25sModel
from .sentence_transformers import SentenceTransformerModel
from .timm import TimmModel

Expand All @@ -8,4 +9,5 @@
"BLIP2Model",
"BLIP2TextModel",
"BLIP2ImageModel",
"BM25sModel",
]
38 changes: 38 additions & 0 deletions xretrieval/models/bm25.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import bm25s
import numpy as np
import Stemmer

from xretrieval.models_registry import ModelRegistry


@ModelRegistry.register("xhluca/bm25s", model_input="text")
class BM25sModel:
def __init__(self, model_id: str = "xhluca/bm25s"):
self.model_id = model_id
self.model = self.load_model()

self.corpus_tokens = None
self.stemmer = Stemmer.Stemmer("english")

def load_model(self):
return bm25s.BM25()

def tokenize_text(self, text: list[str]):
corpus_tokens = bm25s.tokenize(text, stopwords="en", stemmer=self.stemmer)
self.model.index(corpus_tokens)
self.corpus_tokens = corpus_tokens

def retrieve(self, queries: list[str], top_k: int) -> np.ndarray:
queries_tokens = bm25s.tokenize(queries, stopwords="en", stemmer=self.stemmer)
results = self.model.retrieve(
queries_tokens, k=top_k + 1
) # +1 for self-matches

retrieved_ids = []
# Filter self matches for each query
for idx, docs in enumerate(results.documents):
filtered_docs = [doc for doc in docs if doc != idx][:top_k]
retrieved_ids.append(filtered_docs)
retrieved_ids = np.array(retrieved_ids)

return retrieved_ids

0 comments on commit 93b0722

Please sign in to comment.