Skip to content

Commit

Permalink
Merge pull request #355 from aurelio-labs/ashraq/sparse-vector
Browse files Browse the repository at this point in the history
feat: add sparse vector for PineconeIndex
  • Loading branch information
jamescalam authored Jul 16, 2024
2 parents ae08fab + 75a1799 commit a59e7d1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.52"
version = "0.0.53"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <james@aurelio.ai>",
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]

__version__ = "0.0.50"
__version__ = "0.0.53"
29 changes: 29 additions & 0 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,35 @@ def query(
route_names = [self.routes[i] for i in idx]
return scores, route_names

async def aquery(
self,
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query and return top_k results.
"""
if self.index is None or self.routes is None:
raise ValueError("Index or routes are not populated.")
if route_filter is not None:
filtered_index = []
filtered_routes = []
for route, vec in zip(self.routes, self.index):
if route in route_filter:
filtered_index.append(vec)
filtered_routes.append(route)
if not filtered_routes:
raise ValueError("No routes found matching the filter criteria.")
sim = similarity_matrix(vector, np.array(filtered_index))
scores, idx = top_scores(sim, top_k)
route_names = [filtered_routes[i] for i in idx]
else:
sim = similarity_matrix(vector, self.index)
scores, idx = top_scores(sim, top_k)
route_names = [self.routes[i] for i in idx]
return scores, route_names

def delete(self, route_name: str):
"""
Delete all records of a specific route from the index.
Expand Down
40 changes: 40 additions & 0 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,25 @@ def query(
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query vector and return the top_k results.
:param vector: The query vector to search for.
:type vector: np.ndarray
:param top_k: The number of top results to return, defaults to 5.
:type top_k: int, optional
:param route_filter: A list of route names to filter the search results, defaults to None.
:type route_filter: Optional[List[str]], optional
:param kwargs: Additional keyword arguments for the query, including sparse_vector.
:type kwargs: Any
:keyword sparse_vector: An optional sparse vector to include in the query.
:type sparse_vector: Optional[dict]
:return: A tuple containing an array of scores and a list of route names.
:rtype: Tuple[np.ndarray, List[str]]
:raises ValueError: If the index is not populated.
"""
if self.index is None:
raise ValueError("Index is not populated.")
query_vector_list = vector.tolist()
Expand All @@ -474,6 +492,7 @@ def query(
filter_query = None
results = self.index.query(
vector=[query_vector_list],
sparse_vector=kwargs.get("sparse_vector", None),
top_k=top_k,
filter=filter_query,
include_metadata=True,
Expand All @@ -488,7 +507,25 @@ async def aquery(
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]:
"""
Asynchronously search the index for the query vector and return the top_k results.
:param vector: The query vector to search for.
:type vector: np.ndarray
:param top_k: The number of top results to return, defaults to 5.
:type top_k: int, optional
:param route_filter: A list of route names to filter the search results, defaults to None.
:type route_filter: Optional[List[str]], optional
:param kwargs: Additional keyword arguments for the query, including sparse_vector.
:type kwargs: Any
:keyword sparse_vector: An optional sparse vector to include in the query.
:type sparse_vector: Optional[dict]
:return: A tuple containing an array of scores and a list of route names.
:rtype: Tuple[np.ndarray, List[str]]
:raises ValueError: If the index is not populated.
"""
if self.async_client is None or self.host is None:
raise ValueError("Async client or host are not initialized.")
query_vector_list = vector.tolist()
Expand All @@ -498,6 +535,7 @@ async def aquery(
filter_query = None
results = await self._async_query(
vector=query_vector_list,
sparse_vector=kwargs.get("sparse_vector", None),
namespace=self.namespace or "",
filter=filter_query,
top_k=top_k,
Expand All @@ -514,13 +552,15 @@ def delete_index(self):
async def _async_query(
self,
vector: list[float],
sparse_vector: Optional[dict] = None,
namespace: str = "",
filter: Optional[dict] = None,
top_k: int = 5,
include_metadata: bool = False,
):
params = {
"vector": vector,
"sparse_vector": sparse_vector,
"namespace": namespace,
"filter": filter,
"top_k": top_k,
Expand Down

0 comments on commit a59e7d1

Please sign in to comment.