Skip to content

Commit

Permalink
Initial commit to review changes
Browse files (browse the repository at this point in the history)
  • Loading branch information
zahid-syed committed Mar 19, 2024
1 parent 931af0c commit f517a55
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 13 deletions.
7 changes: 6 additions & 1 deletion semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def describe(self) -> dict:
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
def query(
self,
vector: np.ndarray,
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query_vector and return top_k results.
This method should be implemented by subclasses.
Expand Down
30 changes: 24 additions & 6 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,35 @@ def describe(self) -> dict:
"vectors": self.index.shape[0] if self.index is not None else 0,
}

def query(
    self,
    vector: np.ndarray,
    top_k: int = 5,
    route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
    """
    Search the index for the query and return top_k results.

    Args:
        vector: Query vector, compared against the stored vectors with
            ``similarity_matrix``.
        top_k: Number of results requested from ``top_scores``.
        route_filter: Optional list of route names. When given, only
            vectors whose route name is in this list are scored.

    Returns:
        A ``(scores, route_names)`` tuple: the scores as returned by
        ``top_scores`` and the route name of each selected vector.

    Raises:
        ValueError: If the index or routes are not populated, or if
            ``route_filter`` matches none of the stored routes.
    """
    # Local import: the file's top-level import block is not visible here.
    import logging

    if self.index is None or self.routes is None:
        raise ValueError("Index or routes are not populated.")

    log = logging.getLogger(__name__)

    if route_filter is None:
        candidate_vectors = self.index
        candidate_routes = self.routes
    else:
        log.debug("Filtering routes with filter: %s", route_filter)
        # A set makes each membership test O(1) instead of scanning the list.
        allowed = set(route_filter)
        keep = [
            pos for pos, route in enumerate(self.routes) if route in allowed
        ]
        if not keep:
            raise ValueError("No routes found matching the filter criteria.")
        candidate_vectors = np.array([self.index[pos] for pos in keep])
        candidate_routes = [self.routes[pos] for pos in keep]

    # Single scoring path shared by the filtered and unfiltered cases.
    sim = similarity_matrix(vector, candidate_vectors)
    scores, idx = top_scores(sim, top_k)
    # idx indexes into the candidate subset, not the full index.
    route_names = [candidate_routes[i] for i in idx]
    log.debug("Routes considered for similarity calculation: %s", route_names)
    return scores, route_names

def delete(self, route_name: str):
Expand Down
14 changes: 13 additions & 1 deletion semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,29 @@ def describe(self) -> dict:
else:
raise ValueError("Index is None, cannot describe index stats.")

def query(
    self,
    vector: np.ndarray,
    top_k: int = 5,
    route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
    """
    Query the remote index and return the scores and route names of the matches.

    Args:
        vector: Query vector; converted with ``tolist()`` before being sent.
        top_k: Number of matches requested from the index.
        route_filter: Optional list of route names. When given, it is sent
            as a ``{"sr_route": {"$in": route_filter}}`` metadata filter;
            otherwise no filter is applied.

    Returns:
        A ``(scores, route_names)`` tuple built from ``results["matches"]``:
        the scores as a numpy array and the ``sr_route`` metadata value of
        each match.

    Raises:
        ValueError: If the index is not populated.
    """
    # Local import: the file's top-level import block is not visible here.
    import logging

    if self.index is None:
        raise ValueError("Index is not populated.")

    log = logging.getLogger(__name__)
    if route_filter is not None:
        log.debug("Filtering routes with filter: %s", route_filter)
    filter_query = (
        {"sr_route": {"$in": route_filter}} if route_filter is not None else None
    )

    results = self.index.query(
        vector=[vector.tolist()],
        top_k=top_k,
        filter=filter_query,
        include_metadata=True,
    )
    matches = results["matches"]
    scores = [match["score"] for match in matches]
    route_names = [match["metadata"]["sr_route"] for match in matches]
    log.debug("Routes considered for similarity calculation: %s", route_names)
    return np.array(scores), route_names

def delete_index(self):
Expand Down
18 changes: 13 additions & 5 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,16 @@ def __call__(
text: Optional[str] = None,
vector: Optional[List[float]] = None,
simulate_static: bool = False,
route_filter: Optional[List[str]] = None,
) -> RouteChoice:
# if no vector provided, encode text to get vector
if vector is None:
if text is None:
raise ValueError("Either text or vector must be provided")
vector = self._encode(text=text)

route, top_class_scores = self._retrieve_top_route(vector)
route, top_class_scores = self._retrieve_top_route(vector, route_filter)
print(f"Selected route: {route.name if route else 'None'}")
passed = self._check_threshold(top_class_scores, route)

if passed and route is not None and not simulate_static:
Expand Down Expand Up @@ -271,14 +273,16 @@ def __call__(
return RouteChoice()

def _retrieve_top_route(
self, vector: List[float]
self, vector: List[float], route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]:
"""
Retrieve the top matching route based on the given vector.
Returns a tuple of the route (if any) and the scores of the top class.
"""
# get relevant results (scores and routes)
results = self._retrieve(xq=np.array(vector), top_k=self.top_k)
results = self._retrieve(
xq=np.array(vector), top_k=self.top_k, route_filter=route_filter
)
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
Expand Down Expand Up @@ -397,10 +401,14 @@ def _encode(self, text: str) -> Any:
xq = np.squeeze(xq) # Reduce to 1d array.
return xq

def _retrieve(self, xq: Any, top_k: int = 5) -> List[dict]:
def _retrieve(
self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
) -> List[dict]:
"""Given a query vector, retrieve the top_k most similar records."""
# get scores and routes
scores, routes = self.index.query(vector=xq, top_k=top_k)
scores, routes = self.index.query(
vector=xq, top_k=top_k, route_filter=route_filter
)
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]

def _set_aggregation_method(self, aggregation: str = "sum"):
Expand Down

0 comments on commit f517a55

Please sign in to comment.