Skip to content

Commit

Permalink
Merge branch 'main' into add_spacy
Browse files Browse the repository at this point in the history
  • Loading branch information
mesax1 committed Mar 21, 2024
2 parents dc339b1 + 9c38431 commit bb378c1
Show file tree
Hide file tree
Showing 7 changed files with 597 additions and 57 deletions.
236 changes: 235 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ torchvision = { version = "^0.17.0", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}
qdrant-client = {version="^1.8.0", optional = true}


[tool.poetry.extras]
Expand All @@ -47,6 +48,7 @@ pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib", "spacy"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
2 changes: 2 additions & 0 deletions semantic_router/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex
from semantic_router.index.qdrant import QdrantIndex

# Names re-exported as the public API of the `semantic_router.index` package.
__all__ = [
"BaseIndex",
"LocalIndex",
"QdrantIndex",
"PineconeIndex",
]
247 changes: 247 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from pydantic.v1 import Field

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric

# Collection name used by `QdrantIndex` when `index_name` is not supplied.
DEFAULT_COLLECTION_NAME = "semantic-router-index"
# Default `batch_size` for uploads performed by `QdrantIndex.add`.
DEFAULT_UPLOAD_BATCH_SIZE = 100
# Page size (`limit`) requested per scroll call in `QdrantIndex.get_routes`.
SCROLL_SIZE = 1000
# Payload keys under which each stored point keeps its utterance text and
# its route name, respectively.
SR_UTTERANCE_PAYLOAD_KEY = "sr_utterance"
SR_ROUTE_PAYLOAD_KEY = "sr_route"


class QdrantIndex(BaseIndex):
    """Route index that stores utterance embeddings in a Qdrant collection.

    Every utterance is uploaded as one Qdrant point whose payload holds the
    route name (``SR_ROUTE_PAYLOAD_KEY``) and the utterance text
    (``SR_UTTERANCE_PAYLOAD_KEY``). The collection is created on the first
    call to :meth:`add` if it does not exist yet.

    The connection fields (``location``, ``url``, ``host``, ...) are passed
    unchanged to ``qdrant_client.QdrantClient``. ``qdrant-client`` is an
    optional dependency and is imported lazily inside the methods that need
    it.
    """

    index_name: str = Field(
        default=DEFAULT_COLLECTION_NAME,
        description="Name of the Qdrant collection."
        f"Default: '{DEFAULT_COLLECTION_NAME}'",
    )
    location: Optional[str] = Field(
        default=":memory:",
        description="If ':memory:' - use an in-memory Qdrant instance."
        "Used as 'url' value otherwise",
    )
    url: Optional[str] = Field(
        default=None,
        description="Qualified URL of the Qdrant instance."
        "Optional[scheme], host, Optional[port], Optional[prefix]",
    )
    port: Optional[int] = Field(
        default=6333,
        description="Port of the REST API interface.",
    )
    grpc_port: int = Field(
        default=6334,
        description="Port of the gRPC interface.",
    )
    # Annotated Optional because the default is None (the value is handed to
    # QdrantClient as-is, where a falsy value means "do not prefer gRPC").
    prefer_grpc: Optional[bool] = Field(
        default=None,
        description="Whether to use gPRC interface whenever possible in methods",
    )
    https: Optional[bool] = Field(
        default=None,
        description="Whether to use HTTPS(SSL) protocol.",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication in Qdrant Cloud.",
    )
    prefix: Optional[str] = Field(
        default=None,
        description="Prefix to the REST URL path. Example: `http://localhost:6333/some/prefix/{qdrant-endpoint}`.",
    )
    timeout: Optional[int] = Field(
        default=None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: Optional[str] = Field(
        default=None,
        description="Host name of Qdrant service."
        "If url and host are None, set to 'localhost'.",
    )
    path: Optional[str] = Field(
        default=None,
        description="Persistence path for Qdrant local",
    )
    grpc_options: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Options to be passed to the low-level GRPC client, if used.",
    )
    dimensions: Union[int, None] = Field(
        default=None,
        description="Embedding dimensions."
        "Defaults to the embedding length of the configured encoder.",
    )
    metric: Metric = Field(
        default=Metric.COSINE,
        description="Distance metric to use for similarity search.",
    )
    config: Optional[Dict[str, Any]] = Field(
        default={},
        description="Collection options passed to `QdrantClient#create_collection`.",
    )
    # Live QdrantClient instance; excluded from model serialization.
    client: Any = Field(default=None, exclude=True)

    def __init__(self, **kwargs):
        """Validate the fields and open the Qdrant client connection."""
        super().__init__(**kwargs)
        self.type = "qdrant"
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Build a ``QdrantClient`` from the connection fields.

        Raises:
            ImportError: If the optional ``qdrant-client`` package is not
                installed.
        """
        # Only the import is guarded, so an ImportError raised while the
        # client is being constructed is not mislabelled as a missing package.
        try:
            from qdrant_client import QdrantClient
        except ImportError as e:
            raise ImportError(
                "Please install 'qdrant-client' to use QdrantIndex."
                "You can install it with: "
                "`pip install 'semantic-router[qdrant]'`"
            ) from e

        return QdrantClient(
            location=self.location,
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            prefer_grpc=self.prefer_grpc,
            https=self.https,
            api_key=self.api_key,
            prefix=self.prefix,
            timeout=self.timeout,
            host=self.host,
            path=self.path,
            grpc_options=self.grpc_options,
        )

    def _init_collection(self) -> None:
        """Create the collection named ``index_name`` if it does not exist.

        Raises:
            ValueError: If the collection has to be created but
                ``dimensions`` is not set, or the metric is unsupported.
        """
        from qdrant_client import QdrantClient, models

        self.client: QdrantClient
        if self.client.collection_exists(self.index_name):
            return

        if not self.dimensions:
            raise ValueError(
                "Cannot create a collection without specifying the dimensions."
            )

        self.client.create_collection(
            collection_name=self.index_name,
            vectors_config=models.VectorParams(
                size=self.dimensions, distance=self.convert_metric(self.metric)
            ),
            # `config` is Optional, so an explicit None must not break the
            # keyword expansion.
            **(self.config or {}),
        )

    def add(
        self,
        embeddings: List[List[float]],
        routes: List[str],
        utterances: List[str],
        batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
    ):
        """Upload embeddings with their route/utterance payloads.

        Args:
            embeddings: One vector per utterance.
            routes: Route name for each embedding.
            utterances: Utterance text for each embedding.
            batch_size: Number of points sent per upload batch.

        Raises:
            ValueError: If the collection must be created and the vector
                size is neither configured nor inferable from ``embeddings``.
        """
        # Infer the vector size from the first embedding when it was not
        # configured. With empty input nothing can be inferred, and
        # `_init_collection` reports the problem with a clear error instead
        # of an IndexError here.
        if not self.dimensions and len(embeddings) > 0:
            self.dimensions = len(embeddings[0])
        self._init_collection()

        payloads = [
            {SR_ROUTE_PAYLOAD_KEY: route, SR_UTTERANCE_PAYLOAD_KEY: utterance}
            for route, utterance in zip(routes, utterances)
        ]

        # UUIDs are autogenerated by qdrant-client if not provided explicitly
        self.client.upload_collection(
            self.index_name,
            vectors=embeddings,
            payload=payloads,
            batch_size=batch_size,
        )

    def get_routes(self) -> List[Tuple]:
        """
        Gets a list of route and utterance objects currently stored in the index.

        Returns:
            List[Tuple]: A list of (route_name, utterance) objects.
        """
        # `PointId` is defined in qdrant-client's own gRPC bindings, not in
        # the top-level `grpc` (grpcio) package.
        from qdrant_client import grpc

        results = []
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            records, next_offset = self.client.scroll(
                self.index_name,
                limit=SCROLL_SIZE,
                offset=next_offset,
                with_payload=True,
            )
            # Scrolling ends when no further offset is returned; an empty
            # gRPC PointId (num == 0, uuid == "") is treated the same way.
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId)
                and next_offset.num == 0
                and next_offset.uuid == ""
            )

            results.extend(records)

        route_tuples = [
            (x.payload[SR_ROUTE_PAYLOAD_KEY], x.payload[SR_UTTERANCE_PAYLOAD_KEY])
            for x in results
        ]
        return route_tuples

    def delete(self, route_name: str):
        """Delete every point whose route payload equals ``route_name``."""
        from qdrant_client import models

        self.client.delete(
            self.index_name,
            points_selector=models.Filter(
                must=[
                    models.FieldCondition(
                        key=SR_ROUTE_PAYLOAD_KEY,
                        # Exact match: a text match would also hit other
                        # routes whose names merely contain `route_name`.
                        match=models.MatchValue(value=route_name),
                    )
                ]
            ),
        )

    def describe(self) -> dict:
        """Return the index type, vector size and point count."""
        collection_info = self.client.get_collection(self.index_name)

        return {
            "type": self.type,
            "dimensions": collection_info.config.params.vectors.size,
            "vectors": collection_info.points_count,
        }

    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
        """Search the collection for the points closest to ``vector``.

        Args:
            vector: Query embedding.
            top_k: Maximum number of hits to return.

        Returns:
            A tuple ``(scores, route_names)`` with one entry per hit, in the
            order returned by the client.
        """
        results = self.client.search(
            self.index_name, query_vector=vector, limit=top_k, with_payload=True
        )
        scores = [result.score for result in results]
        route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
        return np.array(scores), route_names

    def delete_index(self):
        """Delete the whole collection named ``index_name``."""
        self.client.delete_collection(self.index_name)

    def convert_metric(self, metric: Metric):
        """Translate a ``Metric`` member into the Qdrant ``Distance`` value.

        Raises:
            ValueError: If ``metric`` has no Qdrant equivalent.
        """
        from qdrant_client.models import Distance

        mapping = {
            Metric.COSINE: Distance.COSINE,
            Metric.EUCLIDEAN: Distance.EUCLID,
            Metric.DOTPRODUCT: Distance.DOT,
            Metric.MANHATTAN: Distance.MANHATTAN,
        }

        if metric not in mapping:
            raise ValueError(f"Unsupported Qdrant similarity metric: {metric}")

        return mapping[metric]

    def __len__(self):
        """Return the number of points stored in the collection."""
        return self.client.get_collection(self.index_name).points_count
4 changes: 2 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def from_yaml(cls, file_path: str):
return cls(encoder=encoder, routes=config.routes)

@classmethod
def from_config(cls, config: LayerConfig):
def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes)
return cls(encoder=encoder, routes=config.routes, index=index)

def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
Expand Down
7 changes: 7 additions & 0 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ class DocumentSplit(BaseModel):
@property
def content(self) -> str:
    """Return the entries of ``self.docs`` joined by single spaces."""
    separator = " "
    return separator.join(self.docs)


class Metric(Enum):
    """Names of the vector comparison metrics an index can be configured with.

    The member values are the lowercase string identifiers of each metric.
    """

    # Cosine similarity.
    COSINE = "cosine"
    # Dot (inner) product.
    DOTPRODUCT = "dotproduct"
    # Euclidean distance.
    EUCLIDEAN = "euclidean"
    # Manhattan distance.
    MANHATTAN = "manhattan"
Loading

0 comments on commit bb378c1

Please sign in to comment.