From 78a2daec28663d2025d7d6cf40e44300786f58b6 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Mon, 18 Mar 2024 16:14:04 +0530 Subject: [PATCH] refactor: Addd metric enum --- semantic_router/index/qdrant.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index cc1ddaef..e112fb35 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -4,6 +4,7 @@ from pydantic.v1 import Field from semantic_router.index.base import BaseIndex +from semantic_router.schema import Metric DEFAULT_COLLECTION_NAME = "semantic-router-index" DEFAULT_UPLOAD_BATCH_SIZE = 100 @@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex): default=None, description="Embedding dimensions. Defaults to the embedding length of the configured encoder.", ) - metric: str = Field( - default="Cosine", description="Distance metric to use for similarity search." + metric: Metric = Field( + default=Metric.COSINE, + description="Distance metric to use for similarity search.", ) collection_options: Optional[Dict[str, Any]] = Field( default={}, @@ -124,8 +126,7 @@ def _init_collection(self) -> None: self.client.create_collection( collection_name=self.index_name, vectors_config=models.VectorParams( - size=self.dimensions, - distance=self.metric, # type: ignore + size=self.dimensions, distance=self.convert_metric(self.metric) ), **self.collection_options, ) @@ -222,5 +223,20 @@ def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[st def delete_index(self): self.client.delete_collection(self.index_name) + def convert_metric(self, metric: Metric): + from qdrant_client.models import Distance + + mapping = { + Metric.COSINE: Distance.COSINE, + Metric.EUCLIDEAN: Distance.EUCLID, + Metric.DOTPRODUCT: Distance.DOT, + Metric.MANHATTAN: Distance.MANHATTAN, + } + + if metric not in mapping: + raise ValueError(f"Unsupported Qdrant similarity metric: {metric}") + + return mapping[metric] + def __len__(self): return self.client.get_collection(self.index_name).points_count