From 3c74deeb98629a58abcf32f9e2efc5c763c55fd8 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 8 Oct 2024 16:51:23 -0400 Subject: [PATCH] Fix sentence transformers reranker import (#231) Sentence transformers was not being dynamically imported into the reranker module for hugging face, causing dependency issues for anyone using a reranker. Fixes https://github.com/redis/redis-vl-python/issues/229 --- redisvl/utils/rerank/cohere.py | 2 +- redisvl/utils/rerank/hf_cross_encoder.py | 31 +++++++++++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index 7b5c0927..87163c98 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -83,7 +83,7 @@ def _initialize_clients(self, api_config: Optional[Dict]): from cohere import AsyncClient, Client except ImportError: raise ImportError( - "Cohere vectorizer requires the cohere library. \ + "Cohere reranker requires the cohere library. \ Please install with `pip install cohere`" ) diff --git a/redisvl/utils/rerank/hf_cross_encoder.py b/redisvl/utils/rerank/hf_cross_encoder.py index 2fc0f908..65e323a8 100644 --- a/redisvl/utils/rerank/hf_cross_encoder.py +++ b/redisvl/utils/rerank/hf_cross_encoder.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from sentence_transformers import CrossEncoder +from pydantic.v1 import PrivateAttr from redisvl.utils.rerank.base import BaseReranker @@ -31,25 +31,44 @@ class HFCrossEncoderReranker(BaseReranker): ) """ + _client: Any = PrivateAttr() + def __init__( self, - model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", + model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", limit: int = 3, return_score: bool = True, + **kwargs, ) -> None: """ Initialize the HFCrossEncoderReranker with a specified model and ranking criteria. Parameters: - model_name (str): The name or path of the cross-encoder model to use for reranking. + model (str): The name or path of the cross-encoder model to use for reranking. Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'. limit (int): The maximum number of results to return after reranking. Must be a positive integer. return_score (bool): Whether to return scores alongside the reranked results. """ + model = model or kwargs.pop("model_name", None) super().__init__( - model=model_name, rank_by=None, limit=limit, return_score=return_score + model=model, rank_by=None, limit=limit, return_score=return_score ) - self.model: CrossEncoder = CrossEncoder(model_name) + self._initialize_client(**kwargs) + + def _initialize_client(self, **kwargs): + """ + Setup the huggingface cross-encoder client using optional kwargs. + """ + # Dynamic import of the sentence-transformers module + try: + from sentence_transformers import CrossEncoder + except ImportError: + raise ImportError( + "HFCrossEncoder reranker requires the sentence-transformers library. \ + Please install with `pip install sentence-transformers`" + ) + + self._client = CrossEncoder(self.model, **kwargs) def rank( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs @@ -97,7 +116,7 @@ def rank( texts = [str(doc) for doc in docs] doc_subset = [{"content": doc} for doc in docs] - scores = self.model.predict([(query, text) for text in texts]) + scores = self._client.predict([(query, text) for text in texts]) scores = [float(score) for score in scores] docs_with_scores = list(zip(doc_subset, scores)) docs_with_scores.sort(key=lambda x: x[1], reverse=True)