diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 3ef16206..432f0af9 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -220,29 +220,18 @@ def __call__( self, text: Optional[str] = None, vector: Optional[List[float]] = None, + simulate_static: bool = False, ) -> RouteChoice: # if no vector provided, encode text to get vector if vector is None: if text is None: raise ValueError("Either text or vector must be provided") - vector_arr = self._encode(text=text) - else: - vector_arr = np.array(vector) - # get relevant results (scores and routes) - results = self._retrieve(xq=vector_arr) - # decide most relevant routes - top_class, top_class_scores = self._semantic_classify(results) - # TODO do we need this check? - route = self.check_for_matching_routes(top_class) - if route is None: - return RouteChoice() - threshold = ( - route.score_threshold - if route.score_threshold is not None - else self.score_threshold - ) - passed = self._pass_threshold(top_class_scores, threshold) - if passed: + vector = self._encode(text=text) + + route, top_class_scores = self._retrieve_top_route(vector) + passed = self._check_threshold(top_class_scores, route) + + if passed and route is not None and not simulate_static: if route.function_schema and text is None: raise ValueError( "Route has a function schema, but no text was provided." @@ -260,10 +249,45 @@ def __call__( else: route.llm = self.llm return route(text) + elif passed and route is not None and simulate_static: + return RouteChoice( + name=route.name, + function_call=None, + similarity_score=None, + trigger=None, + ) else: # if no route passes threshold, return empty route choice return RouteChoice() + def _retrieve_top_route( + self, vector: List[float] + ) -> Tuple[Optional[Route], List[float]]: + """ + Retrieve the top matching route based on the given vector. + Returns a tuple of the route (if any) and the scores of the top class. + """ + # get relevant results (scores and routes) + results = self._retrieve(xq=np.array(vector)) + # decide most relevant routes + top_class, top_class_scores = self._semantic_classify(results) + # TODO do we need this check? + route = self.check_for_matching_routes(top_class) + return route, top_class_scores + + def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool: + """ + Check if the route's score passes the specified threshold. + """ + if route is None: + return False + threshold = ( + route.score_threshold + if route.score_threshold is not None + else self.score_threshold + ) + return self._pass_threshold(scores, threshold) + def __str__(self): return ( f"RouteLayer(encoder={self.encoder}, " @@ -481,7 +505,8 @@ def _vec_evaluate(self, Xq: Union[List[float], Any], y: List[str]) -> float: """ correct = 0 for xq, target_route in zip(Xq, y): - route_choice = self(vector=xq) + # We treate dynamic routes as static here, because when evaluating we use only vectors, and dynamic routes expect strings by default. + route_choice = self(vector=xq, simulate_static=True) if route_choice.name == target_route: correct += 1 accuracy = correct / len(Xq)