Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: QdrantIndex #207

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 235 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ torchvision = { version = "^0.17.0", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}
qdrant-client = {version="^1.8.0", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
Expand All @@ -45,6 +46,7 @@ pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
2 changes: 2 additions & 0 deletions semantic_router/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex
from semantic_router.index.qdrant import QdrantIndex

__all__ = [
"BaseIndex",
"LocalIndex",
"QdrantIndex",
"PineconeIndex",
]
226 changes: 226 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from pydantic.v1 import Field

from semantic_router.index.base import BaseIndex

DEFAULT_COLLECTION_NAME = "semantic-router-collection"
DEFAULT_UPLOAD_BATCH_SIZE = 100
SCROLL_SIZE = 1000
SR_UTTERANCE_PAYLOAD_KEY = "sr_utterance"
SR_ROUTE_PAYLOAD_KEY = "sr_route"


class QdrantIndex(BaseIndex):
"The name of the collection to use"

collection_name: str = Field(
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
default=DEFAULT_COLLECTION_NAME,
description=f"The name of the Qdrant collection to use. Defaults to '{DEFAULT_COLLECTION_NAME}'",
)
location: Optional[str] = Field(
default=":memory:",
description="If ':memory:' - use an in-memory Qdrant instance. Used as 'url' value otherwise",
)
url: Optional[str] = Field(
default=None,
description="Qualified URL of the Qdrant instance. Optional[scheme], host, Optional[port], Optional[prefix]",
)
port: Optional[int] = Field(
default=6333,
description="Port of the REST API interface.",
)
grpc_port: int = Field(
default=6334,
description="Port of the gRPC interface.",
)
prefer_grpc: bool = Field(
default=None,
description="Whether to use gPRC interface whenever possible in methods",
)
https: Optional[bool] = Field(
default=None,
description="Whether to use HTTPS(SSL) protocol.",
)
api_key: Optional[str] = Field(
default=None,
description="API key for authentication in Qdrant Cloud.",
)
prefix: Optional[str] = Field(
default=None,
description="Prefix to the REST URL path. Example: `http://localhost:6333/some/prefix/{qdrant-endpoint}`.",
)
timeout: Optional[int] = Field(
default=None,
description="Timeout for REST and gRPC API requests.",
)
host: Optional[str] = Field(
default=None,
description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
)
path: Optional[str] = Field(
default=None,
description="Persistence path for Qdrant local",
)
grpc_options: Optional[Dict[str, Any]] = Field(
default=None,
description="Options to be passed to the low-level Qdrant GRPC client, if used.",
)
size: Union[int, None] = Field(
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
default=None,
description="Embedding dimensions. Defaults to the embedding length of the configured encoder.",
)
distance: str = Field(
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
default="Cosine", description="Distance metric to use for similarity search."
)
collection_options: Optional[Dict[str, Any]] = Field(
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
default={},
description="Additonal options to be passed to `QdrantClient#create_collection`.",
)
client: Any = Field(default=None, exclude=True)

def __init__(self, **data):
super().__init__(**data)
self.type = "qdrant"
self.client = self._initialize_client()

def _initialize_client(self):
try:
from qdrant_client import QdrantClient

return QdrantClient(
location=self.location,
url=self.url,
port=self.port,
grpc_port=self.grpc_port,
prefer_grpc=self.prefer_grpc,
https=self.https,
api_key=self.api_key,
prefix=self.prefix,
timeout=self.timeout,
host=self.host,
path=self.path,
grpc_options=self.grpc_options,
)

except ImportError as e:
raise ImportError(

Check warning on line 108 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L107-L108

Added lines #L107 - L108 were not covered by tests
"Please install 'qdrant-client' to use QdrantIndex."
"You can install it with: "
"`pip install 'semantic-router[qdrant]'`"
) from e

def _init_collection(self) -> None:
from qdrant_client import QdrantClient, models

self.client: QdrantClient
if not self.client.collection_exists(self.collection_name):
if not self.dimensions:
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(

Check warning on line 120 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L120

Added line #L120 was not covered by tests
"Cannot create a collection without specifying the dimensions."
)

self.client.create_collection(
collection_name=self.collection_name,
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
vectors_config=models.VectorParams(
size=self.dimensions,
distance=self.distance, # type: ignore
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
),
jamescalam marked this conversation as resolved.
Show resolved Hide resolved
**self.collection_options,
)

def add(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
):
self.dimensions = self.dimensions or len(embeddings[0])
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
self._init_collection()

payloads = [
{SR_ROUTE_PAYLOAD_KEY: route, SR_UTTERANCE_PAYLOAD_KEY: utterance}
for route, utterance in zip(routes, utterances)
]

# UUIDs are autogenerated by qdrant-client if not provided explicitly
self.client.upload_collection(
self.collection_name,
vectors=embeddings,
payload=payloads,
batch_size=batch_size,
)

def get_routes(self) -> List[Tuple]:
"""
Gets a list of route and utterance objects currently stored in the index.

Returns:
List[Tuple]: A list of (route_name, utterance) objects.
"""

import grpc

Check warning on line 164 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L164

Added line #L164 was not covered by tests

results = []
next_offset = None
stop_scrolling = False
while not stop_scrolling:
records, next_offset = self.client.scroll(

Check warning on line 170 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L166-L170

Added lines #L166 - L170 were not covered by tests
self.collection_name,
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
limit=SCROLL_SIZE,
offset=next_offset,
with_payload=True,
)
stop_scrolling = next_offset is None or (

Check warning on line 176 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L176

Added line #L176 was not covered by tests
isinstance(next_offset, grpc.PointId)
and next_offset.num == 0
and next_offset.uuid == ""
)

results.extend(records)

Check warning on line 182 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L182

Added line #L182 was not covered by tests

route_tuples = [

Check warning on line 184 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L184

Added line #L184 was not covered by tests
(x.payload[SR_ROUTE_PAYLOAD_KEY], x.payload[SR_UTTERANCE_PAYLOAD_KEY])
for x in results
]
return route_tuples

Check warning on line 188 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L188

Added line #L188 was not covered by tests

def delete(self, route_name: str):
from qdrant_client import models

self.client.delete(
self.collection_name,
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
points_selector=models.Filter(
must=[
models.FieldCondition(
key=SR_ROUTE_PAYLOAD_KEY,
match=models.MatchText(text=route_name),
)
]
),
)

def describe(self) -> dict:
collection_info = self.client.get_collection(self.collection_name)
Anush008 marked this conversation as resolved.
Show resolved Hide resolved

return {
"type": self.type,
"dimensions": collection_info.config.params.vectors.size,
"vectors": collection_info.points_count,
}

def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
results = self.client.search(
self.collection_name, query_vector=vector, limit=top_k, with_payload=True
)
scores = [result.score for result in results]
route_names = [result.payload["sr_route"] for result in results]
return np.array(scores), route_names

def delete_index(self):
self.client.delete_collection(self.collection_name)

Check warning on line 223 in semantic_router/index/qdrant.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/qdrant.py#L223

Added line #L223 was not covered by tests
Anush008 marked this conversation as resolved.
Show resolved Hide resolved

def __len__(self):
return self.client.get_collection(self.collection_name).points_count
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def from_yaml(cls, file_path: str):
return cls(encoder=encoder, routes=config.routes)

@classmethod
def from_config(cls, config: LayerConfig):
def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes)
return cls(encoder=encoder, routes=config.routes, index=index)

def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
Expand Down
Loading
Loading