Skip to content

Commit

Permalink
refactor: Addd metric enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Mar 18, 2024
1 parent a5703d5 commit 78a2dae
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic.v1 import Field

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric

DEFAULT_COLLECTION_NAME = "semantic-router-index"
DEFAULT_UPLOAD_BATCH_SIZE = 100
Expand Down Expand Up @@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex):
default=None,
description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
)
metric: str = Field(
default="Cosine", description="Distance metric to use for similarity search."
metric: Metric = Field(
default=Metric.COSINE,
description="Distance metric to use for similarity search.",
)
collection_options: Optional[Dict[str, Any]] = Field(
default={},
Expand Down Expand Up @@ -124,8 +126,7 @@ def _init_collection(self) -> None:
self.client.create_collection(
collection_name=self.index_name,
vectors_config=models.VectorParams(
size=self.dimensions,
distance=self.metric, # type: ignore
size=self.dimensions, distance=self.convert_metric(self.metric)
),
**self.collection_options,
)
Expand Down Expand Up @@ -222,5 +223,20 @@ def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[st
def delete_index(self):
self.client.delete_collection(self.index_name)

def convert_metric(self, metric: Metric):
from qdrant_client.models import Distance

mapping = {
Metric.COSINE: Distance.COSINE,
Metric.EUCLIDEAN: Distance.EUCLID,
Metric.DOTPRODUCT: Distance.DOT,
Metric.MANHATTAN: Distance.MANHATTAN,
}

if metric not in mapping:
raise ValueError(f"Unsupported Qdrant similarity metric: {metric}")

return mapping[metric]

def __len__(self):
return self.client.get_collection(self.index_name).points_count

0 comments on commit 78a2dae

Please sign in to comment.