Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
zahid-syed committed Mar 12, 2024
2 parents ffc5f88 + 43e662d commit fcf7077
Showing 1 changed file with 44 additions and 19 deletions.
63 changes: 44 additions & 19 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,29 +220,18 @@ def __call__(
self,
text: Optional[str] = None,
vector: Optional[List[float]] = None,
simulate_static: bool = False,
) -> RouteChoice:
# if no vector provided, encode text to get vector
if vector is None:
if text is None:
raise ValueError("Either text or vector must be provided")
vector_arr = self._encode(text=text)
else:
vector_arr = np.array(vector)
# get relevant results (scores and routes)
results = self._retrieve(xq=vector_arr)
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
if route is None:
return RouteChoice()
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
passed = self._pass_threshold(top_class_scores, threshold)
if passed:
vector = self._encode(text=text)

route, top_class_scores = self._retrieve_top_route(vector)
passed = self._check_threshold(top_class_scores, route)

if passed and route is not None and not simulate_static:
if route.function_schema and text is None:
raise ValueError(
"Route has a function schema, but no text was provided."
Expand All @@ -260,10 +249,45 @@ def __call__(
else:
route.llm = self.llm
return route(text)
elif passed and route is not None and simulate_static:
return RouteChoice(
name=route.name,
function_call=None,
similarity_score=None,
trigger=None,
)
else:
# if no route passes threshold, return empty route choice
return RouteChoice()

def _retrieve_top_route(
self, vector: List[float]
) -> Tuple[Optional[Route], List[float]]:
"""
Retrieve the top matching route based on the given vector.
Returns a tuple of the route (if any) and the scores of the top class.
"""
# get relevant results (scores and routes)
results = self._retrieve(xq=np.array(vector))
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
return route, top_class_scores

def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
"""
Check if the route's score passes the specified threshold.
"""
if route is None:
return False
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
return self._pass_threshold(scores, threshold)

def __str__(self):
return (
f"RouteLayer(encoder={self.encoder}, "
Expand Down Expand Up @@ -481,7 +505,8 @@ def _vec_evaluate(self, Xq: Union[List[float], Any], y: List[str]) -> float:
"""
correct = 0
for xq, target_route in zip(Xq, y):
route_choice = self(vector=xq)
# We treate dynamic routes as static here, because when evaluating we use only vectors, and dynamic routes expect strings by default.
route_choice = self(vector=xq, simulate_static=True)
if route_choice.name == target_route:
correct += 1
accuracy = correct / len(Xq)
Expand Down

0 comments on commit fcf7077

Please sign in to comment.