From f927dd3d1b183fcc76ec7b10f14cdcdb45a3e7d8 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 12 Aug 2024 13:21:12 -0400 Subject: [PATCH] support dynamic distance threshold --- redisvl/extensions/llmcache/semantic.py | 10 ++++++++-- .../extensions/session_manager/semantic_session.py | 12 +++++++++--- tests/integration/test_llmcache.py | 2 +- tests/integration/test_session_manager.py | 4 ++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 1b78ea9e..9845d925 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -233,6 +233,7 @@ def check( num_results: int = 1, return_fields: Optional[List[str]] = None, filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. @@ -255,6 +256,8 @@ def check( filter_expression (Optional[FilterExpression]) : Optional filter expression that can be used to filter cache results. Defaults to None and the full cache will be searched. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. Returns: List[Dict[str, Any]]: A list of dicts containing the requested @@ -274,9 +277,12 @@ def check( if not (prompt or vector): raise ValueError("Either prompt or vector must be specified.") + # overrides + distance_threshold = distance_threshold or self._distance_threshold + return_fields = return_fields or self.return_fields vector = vector or self._vectorize_prompt(prompt) + self._check_vector_dims(vector) - return_fields = return_fields or self.return_fields if not isinstance(return_fields, list): raise TypeError("return_fields must be a list of field names") @@ -285,7 +291,7 @@ def check( vector=vector, vector_field_name=self.vector_field_name, return_fields=self.return_fields, - distance_threshold=self._distance_threshold, + distance_threshold=distance_threshold, num_results=num_results, return_score=True, filter_expression=filter_expression, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 5ce0318d..773f3fc5 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -137,6 +137,7 @@ def get_relevant( fall_back: bool = False, session_tag: Optional[str] = None, raw: bool = False, + distance_threshold: Optional[float] = None, ) -> Union[List[str], List[Dict[str, str]]]: """Searches the chat history for information semantically related to the specified prompt. @@ -151,10 +152,12 @@ def get_relevant( as_text (bool): Whether to return the prompts and responses as text or as JSON top_k (int): The number of previous messages to return. Default is 5. - fall_back (bool): Whether to drop back to recent conversation history - if no relevant context is found. session_tag (Optional[str]): Tag to be added to entries to link to a specific session. Defaults to instance uuid. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. + fall_back (bool): Whether to drop back to recent conversation history + if no relevant context is found. raw (bool): Whether to return the full Redis hash entry or just the message. @@ -169,6 +172,9 @@ def get_relevant( if top_k == 0: return [] + # override distance threshold + distance_threshold = distance_threshold or self._distance_threshold + return_fields = [ self.session_field_name, self.role_field_name, @@ -187,7 +193,7 @@ def get_relevant( vector=self._vectorizer.embed(prompt), vector_field_name=self.vector_field_name, return_fields=return_fields, - distance_threshold=self._distance_threshold, + distance_threshold=distance_threshold, num_results=top_k, return_score=True, filter_expression=session_filter, diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 03a6f0eb..23a41299 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer): vector = vectorizer.embed(prompt) cache.store(prompt, response, vector=vector) - check_result = cache.check(vector=vector) + check_result = cache.check(vector=vector, distance_threshold=0.4) assert len(check_result) == 1 print(check_result, flush=True) diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index d21e6651..56943447 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -463,6 +463,10 @@ def test_semantic_add_and_get_relevant(semantic_session): semantic_session.set_distance_threshold(0.5) default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system + assert default_context == semantic_session.get_relevant( + "list of fruits and vegetables", + distance_threshold=0.5 + ) # test tool calls can also be returned context = semantic_session.get_relevant("winter sports like skiing")