Skip to content

Commit

Permalink
feat: QdrantIndex (#1)
Browse files Browse the repository at this point in the history
* feat: QdrantIndex

* chore: poetry.lock
  • Loading branch information
Anush008 authored Mar 15, 2024
1 parent 931af0c commit 838767a
Show file tree
Hide file tree
Showing 6 changed files with 564 additions and 57 deletions.
236 changes: 235 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ torchvision = { version = "^0.17.0", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}
qdrant-client = {version="^1.8.0", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
Expand All @@ -45,6 +46,7 @@ pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
2 changes: 2 additions & 0 deletions semantic_router/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex
from semantic_router.index.qdrant import QdrantIndex

# Public names re-exported by the ``semantic_router.index`` package; each one
# is imported from its own submodule above.
__all__ = [
"BaseIndex",
"LocalIndex",
"QdrantIndex",
"PineconeIndex",
]
223 changes: 223 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from pydantic.v1 import Field

from semantic_router.index.base import BaseIndex

# Default value of the ``collection_name`` field of ``QdrantIndex``.
DEFAULT_COLLECTION_NAME = "semantic-router-collection"
# Default ``batch_size`` argument of ``QdrantIndex.add``.
DEFAULT_UPLOAD_BATCH_SIZE = 100
# ``limit`` passed to each ``client.scroll`` call in ``QdrantIndex.get_routes``.
SCROLL_SIZE = 1000


class QdrantIndex(BaseIndex):
    """Route index that stores utterance embeddings in a Qdrant collection.

    Every uploaded point carries a payload with the route name under
    ``sr_route`` and the utterance text under ``sr_utterance``. The
    connection-related fields are forwarded unchanged to ``QdrantClient``
    when the index is constructed.
    """

    collection_name: str = Field(
        default=DEFAULT_COLLECTION_NAME,
        description=f"The name of the Qdrant collection to use. Defaults to '{DEFAULT_COLLECTION_NAME}'",
    )
    location: Optional[str] = Field(
        default=":memory:",
        description="If ':memory:' - use an in-memory Qdrant instance. Used as 'url' value otherwise",
    )
    url: Optional[str] = Field(
        default=None,
        description="Qualified URL of the Qdrant instance. Optional[scheme], host, Optional[port], Optional[prefix]",
    )
    port: Optional[int] = Field(
        default=6333,
        description="Port of the REST API interface.",
    )
    grpc_port: int = Field(
        default=6334,
        description="Port of the gRPC interface.",
    )
    # Annotated Optional because the default is None (left for the client
    # to interpret).
    prefer_grpc: Optional[bool] = Field(
        default=None,
        description="Whether to use gPRC interface whenever possible in methods",
    )
    https: Optional[bool] = Field(
        default=None,
        description="Whether to use HTTPS(SSL) protocol.",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication in Qdrant Cloud.",
    )
    prefix: Optional[str] = Field(
        default=None,
        description="Prefix to the REST URL path. Example: `http://localhost:6333/some/prefix/{qdrant-endpoint}`.",
    )
    timeout: Optional[int] = Field(
        default=None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: Optional[str] = Field(
        default=None,
        description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
    )
    path: Optional[str] = Field(
        default=None,
        description="Persistence path for Qdrant local",
    )
    grpc_options: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Options to be passed to the low-level Qdrant GRPC client, if used.",
    )
    # NOTE(review): no method of this class reads ``size``; collection
    # creation uses ``self.dimensions`` instead. Kept for compatibility.
    size: Union[int, None] = Field(
        default=None,
        description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
    )
    distance: str = Field(
        default="Cosine", description="Distance metric to use for similarity search."
    )
    collection_options: Optional[Dict[str, Any]] = Field(
        default={},
        description="Additonal options to be passed to `QdrantClient#create_collection`.",
    )
    client: Any = Field(default=None, exclude=True)

    def __init__(self, **data):
        """Validate the fields, tag the index as ``"qdrant"`` and open the client."""
        super().__init__(**data)
        self.type = "qdrant"
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Build a ``QdrantClient`` from the connection fields of this index.

        Returns:
            The constructed ``QdrantClient`` instance.

        Raises:
            ImportError: If the optional ``qdrant-client`` package is not
                installed.
        """
        # Only the import is guarded, so an ImportError raised while the
        # client itself is being constructed is not mislabelled as a
        # missing-package problem.
        try:
            from qdrant_client import QdrantClient
        except ImportError as e:
            raise ImportError(
                "Please install 'qdrant-client' to use QdrantIndex. "
                "You can install it with: "
                "`pip install 'semantic-router[qdrant]'`"
            ) from e

        return QdrantClient(
            location=self.location,
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            prefer_grpc=self.prefer_grpc,
            https=self.https,
            api_key=self.api_key,
            prefix=self.prefix,
            timeout=self.timeout,
            host=self.host,
            path=self.path,
            grpc_options=self.grpc_options,
        )

    def _init_collection(self) -> None:
        """Create the collection if it does not exist yet.

        Raises:
            ValueError: If the collection has to be created but
                ``self.dimensions`` is not set.
        """
        from qdrant_client import models

        if self.client.collection_exists(self.collection_name):
            return

        if not self.dimensions:
            raise ValueError(
                "Cannot create a collection without specifying the dimensions."
            )

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.dimensions,
                distance=self.distance,  # type: ignore
            ),
            # ``collection_options`` is Optional; treat None as "no extras"
            # instead of failing on ``**None``.
            **(self.collection_options or {}),
        )

    def add(
        self,
        embeddings: List[List[float]],
        routes: List[str],
        utterances: List[str],
        batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
    ):
        """Upload embeddings together with their route/utterance payloads.

        Args:
            embeddings: One vector per utterance.
            routes: Route name for each embedding.
            utterances: Utterance text for each embedding.
            batch_size: Batch size handed to ``upload_collection``.

        Raises:
            ValueError: Propagated from ``_init_collection`` when the
                collection must be created and the dimensions are unknown
                (for example, empty ``embeddings`` and no configured
                dimensions).
        """
        # Infer the dimensionality from the first vector only when there is
        # one; indexing an empty batch would raise an unhelpful IndexError.
        if not self.dimensions and len(embeddings) > 0:
            self.dimensions = len(embeddings[0])
        self._init_collection()

        payloads = [
            {"sr_route": route, "sr_utterance": utterance}
            for route, utterance in zip(routes, utterances)
        ]

        # UUIDs are autogenerated by qdrant-client if not provided explicitly
        self.client.upload_collection(
            self.collection_name,
            vectors=embeddings,
            payload=payloads,
            batch_size=batch_size,
        )

    def get_routes(self) -> List[Tuple]:
        """
        Gets a list of route and utterance objects currently stored in the index.

        The collection is read page by page with ``client.scroll`` using
        ``SCROLL_SIZE`` records per request.

        Returns:
            List[Tuple]: A list of (route_name, utterance) objects.
        """
        # ``PointId`` is defined by qdrant-client's generated gRPC module,
        # not by the ``grpc`` (grpcio) package.
        from qdrant_client import grpc

        results = []
        next_offset = None
        while True:
            records, next_offset = self.client.scroll(
                self.collection_name,
                limit=SCROLL_SIZE,
                offset=next_offset,
                with_payload=True,
            )
            results.extend(records)

            # Stop when there is no next page: either no offset at all, or
            # an empty gRPC ``PointId`` (num == 0 and uuid == "").
            if next_offset is None or (
                isinstance(next_offset, grpc.PointId)
                and next_offset.num == 0
                and next_offset.uuid == ""
            ):
                break

        return [
            (record.payload["sr_route"], record.payload["sr_utterance"])
            for record in results
        ]

    def delete(self, route_name: str):
        """Delete every point whose ``sr_route`` payload equals ``route_name``.

        Args:
            route_name: Exact name of the route to remove.
        """
        from qdrant_client import models

        # ``MatchValue`` is an exact comparison. ``MatchText`` is a
        # full-text/substring condition and could also remove routes whose
        # names merely contain ``route_name``.
        self.client.delete(
            self.collection_name,
            points_selector=models.Filter(
                must=[
                    models.FieldCondition(
                        key="sr_route",
                        match=models.MatchValue(value=route_name),
                    )
                ]
            ),
        )

    def describe(self) -> dict:
        """Return the index type, vector size and point count of the collection."""
        info = self.client.get_collection(self.collection_name)
        vector_size = info.config.params.vectors.size

        return {
            "type": self.type,
            "dimensions": vector_size,
            "vectors": info.points_count,
        }

    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
        """Search the collection for the points nearest to ``vector``.

        Args:
            vector: Query embedding.
            top_k: Maximum number of hits to request.

        Returns:
            A tuple of (scores as a numpy array, route name of each hit),
            both in the order returned by ``client.search``.
        """
        hits = self.client.search(
            self.collection_name,
            query_vector=vector,
            limit=top_k,
            with_payload=True,
        )
        scores = [hit.score for hit in hits]
        route_names = [hit.payload["sr_route"] for hit in hits]
        return np.array(scores), route_names

    def delete_index(self):
        """Delete the whole collection backing this index."""
        self.client.delete_collection(self.collection_name)

    def __len__(self):
        """Return the ``points_count`` reported for the collection."""
        info = self.client.get_collection(self.collection_name)
        return info.points_count
4 changes: 2 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def from_yaml(cls, file_path: str):
return cls(encoder=encoder, routes=config.routes)

@classmethod
def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
    """Build a layer from a ``LayerConfig``.

    Args:
        config: Supplies ``encoder_type``, ``encoder_name`` and ``routes``.
        index: Optional index instance passed through to the constructor;
            ``None`` is forwarded as-is.

    Returns:
        A new instance of ``cls`` created with the resolved encoder, the
        configured routes and the given index.
    """
    # The diff hunk showed both the pre- and post-change lines; only the
    # post-change signature and return are kept here.
    encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
    return cls(encoder=encoder, routes=config.routes, index=index)

def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
Expand Down
Loading

0 comments on commit 838767a

Please sign in to comment.