Skip to content

Commit

Permalink
refactor: Addd metric enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Mar 18, 2024
1 parent a5703d5 commit c45b679
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
24 changes: 20 additions & 4 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic.v1 import Field

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric

DEFAULT_COLLECTION_NAME = "semantic-router-index"
DEFAULT_UPLOAD_BATCH_SIZE = 100
Expand Down Expand Up @@ -71,8 +72,9 @@ class QdrantIndex(BaseIndex):
default=None,
description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
)
metric: str = Field(
default="Cosine", description="Distance metric to use for similarity search."
metric: Metric = Field(
default=Metric.COSINE,
description="Distance metric to use for similarity search.",
)
collection_options: Optional[Dict[str, Any]] = Field(
default={},
Expand Down Expand Up @@ -124,8 +126,7 @@ def _init_collection(self) -> None:
self.client.create_collection(
collection_name=self.index_name,
vectors_config=models.VectorParams(
size=self.dimensions,
distance=self.metric, # type: ignore
size=self.dimensions, distance=self.convert_metric(self.metric)
),
**self.collection_options,
)
Expand Down Expand Up @@ -222,5 +223,20 @@ def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[st
def delete_index(self):
self.client.delete_collection(self.index_name)

Check warning on line 224 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L224

Added line #L224 was not covered by tests

def convert_metric(self, metric: Metric):
from qdrant_client.models import Distance

mapping = {
Metric.COSINE: Distance.COSINE,
Metric.EUCLIDEAN: Distance.EUCLID,
Metric.DOTPRODUCT: Distance.DOT,
Metric.MANHATTAN: Distance.MANHATTAN,
}

if metric not in mapping:
raise ValueError(f"Unsupported Qdrant similarity metric: {metric}")

Check warning on line 237 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L237

Added line #L237 was not covered by tests

return mapping[metric]

def __len__(self):
return self.client.get_collection(self.index_name).points_count
7 changes: 7 additions & 0 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ class DocumentSplit(BaseModel):
@property
def content(self) -> str:
return " ".join(self.docs)


class Metric(Enum):
COSINE = "cosine"
DOTPRODUCT = "dotproduct"
EUCLIDEAN = "euclidean"
MANHATTAN = "manhattan"

0 comments on commit c45b679

Please sign in to comment.