Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: QdrantIndex #207

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 235 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ torchvision = { version = "^0.17.0", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}
qdrant-client = {version="^1.8.0", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
Expand All @@ -45,6 +46,7 @@ pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
2 changes: 2 additions & 0 deletions semantic_router/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex
from semantic_router.index.qdrant import QdrantIndex

# Index implementations re-exported as the public API of this package.
__all__ = [
    "BaseIndex",
    "LocalIndex",
    "QdrantIndex",
    "PineconeIndex",
]
247 changes: 247 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from pydantic.v1 import Field

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric

# Collection name used when the caller does not supply `index_name`.
DEFAULT_COLLECTION_NAME = "semantic-router-index"
# Default `batch_size` for `QdrantIndex.add`.
DEFAULT_UPLOAD_BATCH_SIZE = 100
# Page size (`limit`) requested per scroll call in `QdrantIndex.get_routes`.
SCROLL_SIZE = 1000
# Payload keys under which each stored point keeps its utterance and route name.
SR_UTTERANCE_PAYLOAD_KEY = "sr_utterance"
SR_ROUTE_PAYLOAD_KEY = "sr_route"


class QdrantIndex(BaseIndex):
    """Route index backed by a Qdrant collection.

    Each utterance embedding is stored as one point whose payload carries the
    route name (``SR_ROUTE_PAYLOAD_KEY``) and the utterance text
    (``SR_UTTERANCE_PAYLOAD_KEY``). The connection fields below are forwarded
    unchanged to ``qdrant_client.QdrantClient``.
    """

    index_name: str = Field(
        default=DEFAULT_COLLECTION_NAME,
        description="Name of the Qdrant collection. "
        f"Default: '{DEFAULT_COLLECTION_NAME}'",
    )
    location: Optional[str] = Field(
        default=":memory:",
        description="If ':memory:' - use an in-memory Qdrant instance. "
        "Used as 'url' value otherwise",
    )
    url: Optional[str] = Field(
        default=None,
        description="Qualified URL of the Qdrant instance. "
        "Optional[scheme], host, Optional[port], Optional[prefix]",
    )
    port: Optional[int] = Field(
        default=6333,
        description="Port of the REST API interface.",
    )
    grpc_port: int = Field(
        default=6334,
        description="Port of the gRPC interface.",
    )
    # Annotated Optional so the declared type matches the None default.
    prefer_grpc: Optional[bool] = Field(
        default=None,
        description="Whether to use gRPC interface whenever possible in methods",
    )
    https: Optional[bool] = Field(
        default=None,
        description="Whether to use HTTPS(SSL) protocol.",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication in Qdrant Cloud.",
    )
    prefix: Optional[str] = Field(
        default=None,
        description="Prefix to the REST URL path. Example: `http://localhost:6333/some/prefix/{qdrant-endpoint}`.",
    )
    timeout: Optional[int] = Field(
        default=None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: Optional[str] = Field(
        default=None,
        description="Host name of Qdrant service. "
        "If url and host are None, set to 'localhost'.",
    )
    path: Optional[str] = Field(
        default=None,
        description="Persistence path for Qdrant local",
    )
    grpc_options: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Options to be passed to the low-level GRPC client, if used.",
    )
    dimensions: Union[int, None] = Field(
        default=None,
        description="Embedding dimensions. "
        "Defaults to the embedding length of the configured encoder.",
    )
    metric: Metric = Field(
        default=Metric.COSINE,
        description="Distance metric to use for similarity search.",
    )
    config: Optional[Dict[str, Any]] = Field(
        default={},
        description="Collection options passed to `QdrantClient#create_collection`.",
    )
    # Live QdrantClient instance; excluded from model export.
    client: Any = Field(default=None, exclude=True)

    def __init__(self, **kwargs):
        """Validate the fields, tag the index type and open the client."""
        super().__init__(**kwargs)
        self.type = "qdrant"
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Build a ``QdrantClient`` from the connection fields.

        Raises:
            ImportError: If the optional ``qdrant-client`` package is missing.
        """
        # Only the import is guarded, so an ImportError raised while
        # constructing the client is not mislabelled as a missing package.
        try:
            from qdrant_client import QdrantClient
        except ImportError as e:
            raise ImportError(
                "Please install 'qdrant-client' to use QdrantIndex. "
                "You can install it with: "
                "`pip install 'semantic-router[qdrant]'`"
            ) from e

        return QdrantClient(
            location=self.location,
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            prefer_grpc=self.prefer_grpc,
            https=self.https,
            api_key=self.api_key,
            prefix=self.prefix,
            timeout=self.timeout,
            host=self.host,
            path=self.path,
            grpc_options=self.grpc_options,
        )

    def _init_collection(self) -> None:
        """Create the collection if it does not exist yet.

        Raises:
            ValueError: If the collection must be created but ``dimensions``
                is not set.
        """
        from qdrant_client import models

        if self.client.collection_exists(self.index_name):
            return

        if not self.dimensions:
            raise ValueError(
                "Cannot create a collection without specifying the dimensions."
            )

        self.client.create_collection(
            collection_name=self.index_name,
            vectors_config=models.VectorParams(
                size=self.dimensions, distance=self.convert_metric(self.metric)
            ),
            # `config` is Optional, so tolerate an explicit None.
            **(self.config or {}),
        )

    def add(
        self,
        embeddings: List[List[float]],
        routes: List[str],
        utterances: List[str],
        batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
    ):
        """Upload embeddings with their route/utterance payloads.

        Args:
            embeddings: One vector per utterance.
            routes: Route name for each embedding.
            utterances: Utterance text for each embedding.
            batch_size: Number of points sent per upload batch.
        """
        if not self.dimensions:
            # With no vectors there is nothing to upload and no way to infer
            # the vector size, so skip instead of failing on embeddings[0].
            if len(embeddings) == 0:
                return
            self.dimensions = len(embeddings[0])
        self._init_collection()

        payloads = [
            {SR_ROUTE_PAYLOAD_KEY: route, SR_UTTERANCE_PAYLOAD_KEY: utterance}
            for route, utterance in zip(routes, utterances)
        ]

        # UUIDs are autogenerated by qdrant-client if not provided explicitly
        self.client.upload_collection(
            self.index_name,
            vectors=embeddings,
            payload=payloads,
            batch_size=batch_size,
        )

    def get_routes(self) -> List[Tuple]:
        """
        Gets a list of route and utterance objects currently stored in the index.

        Returns:
            List[Tuple]: A list of (route_name, utterance) objects.
        """
        # PointId is qdrant-client's protobuf type; the top-level `grpc`
        # (grpcio) package does not define it.
        from qdrant_client import grpc

        results = []
        next_offset = None
        while True:
            records, next_offset = self.client.scroll(
                self.index_name,
                limit=SCROLL_SIZE,
                offset=next_offset,
                with_payload=True,
            )
            results.extend(records)

            if next_offset is None:
                break
            # A PointId with neither `num` nor `uuid` set is treated as
            # "no further page" as well.
            if (
                isinstance(next_offset, grpc.PointId)
                and next_offset.num == 0
                and next_offset.uuid == ""
            ):
                break

        return [
            (x.payload[SR_ROUTE_PAYLOAD_KEY], x.payload[SR_UTTERANCE_PAYLOAD_KEY])
            for x in results
        ]

    def delete(self, route_name: str):
        """Delete every point whose route payload equals ``route_name``."""
        from qdrant_client import models

        self.client.delete(
            self.index_name,
            points_selector=models.Filter(
                must=[
                    models.FieldCondition(
                        key=SR_ROUTE_PAYLOAD_KEY,
                        # Exact match: a text match would also hit routes
                        # whose name merely contains `route_name`.
                        match=models.MatchValue(value=route_name),
                    )
                ]
            ),
        )

    def describe(self) -> dict:
        """Return the index type, vector size and point count."""
        collection_info = self.client.get_collection(self.index_name)

        return {
            "type": self.type,
            "dimensions": collection_info.config.params.vectors.size,
            "vectors": collection_info.points_count,
        }

    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
        """Search the collection for the ``top_k`` nearest points.

        Returns:
            A tuple of (scores array, route names), in the order the hits
            were returned by the client.
        """
        results = self.client.search(
            self.index_name, query_vector=vector, limit=top_k, with_payload=True
        )
        scores = [result.score for result in results]
        route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
        return np.array(scores), route_names

    def delete_index(self):
        """Drop the whole collection."""
        self.client.delete_collection(self.index_name)

    def convert_metric(self, metric: Metric):
        """Map a ``Metric`` member to the matching Qdrant ``Distance``.

        Raises:
            ValueError: If ``metric`` has no Qdrant equivalent.
        """
        from qdrant_client.models import Distance

        mapping = {
            Metric.COSINE: Distance.COSINE,
            Metric.EUCLIDEAN: Distance.EUCLID,
            Metric.DOTPRODUCT: Distance.DOT,
            Metric.MANHATTAN: Distance.MANHATTAN,
        }

        if metric not in mapping:
            raise ValueError(f"Unsupported Qdrant similarity metric: {metric}")

        return mapping[metric]

    def __len__(self):
        """Return the number of points stored in the collection."""
        return self.client.get_collection(self.index_name).points_count
4 changes: 2 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def from_yaml(cls, file_path: str):
return cls(encoder=encoder, routes=config.routes)

@classmethod
def from_config(cls, config: LayerConfig):
def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes)
return cls(encoder=encoder, routes=config.routes, index=index)

def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
Expand Down
7 changes: 7 additions & 0 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ class DocumentSplit(BaseModel):
@property
def content(self) -> str:
return " ".join(self.docs)


class Metric(Enum):
    """Distance metrics an index can be configured to use for similarity search."""

    COSINE = "cosine"
    DOTPRODUCT = "dotproduct"
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"
Loading
Loading