From df41114d790e546bbd35ca660af8d1497973e77c Mon Sep 17 00:00:00 2001
From: Tyler Hutcherson
Date: Mon, 9 Sep 2024 10:06:04 -0700
Subject: [PATCH] Update cache check logic (#216)

Refactors the cache check logic into a single helper method. Adjusts the
cache check and processing logic to better handle `return_fields`
configurations.
---
 redisvl/extensions/llmcache/semantic.py | 73 +++++++++++++------------
 tests/integration/test_llmcache.py      |  6 +-
 2 files changed, 42 insertions(+), 37 deletions(-)

diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 238d6100..375d80c9 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Dict, List, Optional
 
 from redis import Redis
@@ -341,8 +342,10 @@ def check(
                 prompt="What is the captial city of France?"
             )
         """
-        if not (prompt or vector):
+        if not any([prompt, vector]):
             raise ValueError("Either prompt or vector must be specified.")
+        if return_fields and not isinstance(return_fields, list):
+            raise TypeError("Return fields must be a list of values.")
 
         # overrides
         distance_threshold = distance_threshold or self._distance_threshold
@@ -359,25 +362,14 @@ def check(
             filter_expression=filter_expression,
         )
 
-        cache_hits: List[Dict[Any, str]] = []
-
         # Search the cache!
         cache_search_results = self._index.query(query)
-
-        for cache_search_result in cache_search_results:
-            redis_key = cache_search_result.pop("id")
-            self._refresh_ttl(redis_key)
-
-            # Create and process cache hit
-            cache_hit = CacheHit(**cache_search_result)
-            cache_hit_dict = cache_hit.to_dict()
-            # Filter down to only selected return fields if needed
-            if isinstance(return_fields, list) and len(return_fields) > 0:
-                cache_hit_dict = {
-                    k: v for k, v in cache_hit_dict.items() if k in return_fields
-                }
-            cache_hit_dict[self.redis_key_field_name] = redis_key
-            cache_hits.append(cache_hit_dict)
+        redis_keys, cache_hits = self._process_cache_results(
+            cache_search_results, return_fields  # type: ignore
+        )
+        # Extend TTL on keys
+        for key in redis_keys:
+            self._refresh_ttl(key)
 
         return cache_hits
 
@@ -431,19 +423,16 @@
         """
         aindex = await self._get_async_index()
 
-        if not (prompt or vector):
+        if not any([prompt, vector]):
             raise ValueError("Either prompt or vector must be specified.")
+        if return_fields and not isinstance(return_fields, list):
+            raise TypeError("Return fields must be a list of values.")
 
         # overrides
         distance_threshold = distance_threshold or self._distance_threshold
-        return_fields = return_fields or self.return_fields
         vector = vector or await self._avectorize_prompt(prompt)
-
         self._check_vector_dims(vector)
 
-        if not isinstance(return_fields, list):
-            raise TypeError("return_fields must be a list of field names")
-
         query = RangeQuery(
             vector=vector,
             vector_field_name=self.vector_field_name,
@@ -454,24 +443,36 @@
             filter_expression=filter_expression,
         )
 
-        cache_hits: List[Dict[Any, str]] = []
-
         # Search the cache!
         cache_search_results = await aindex.query(query)
+        redis_keys, cache_hits = self._process_cache_results(
+            cache_search_results, return_fields  # type: ignore
+        )
+        # Extend TTL on keys
+        asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys])
 
-        for cache_search_result in cache_search_results:
-            key = cache_search_result["id"]
-            await self._async_refresh_ttl(key)
+        return cache_hits
 
-            # Create cache hit
+    def _process_cache_results(
+        self, cache_search_results: List[Dict[str, Any]], return_fields: List[str]
+    ):
+        redis_keys: List[str] = []
+        cache_hits: List[Dict[Any, str]] = []
+        for cache_search_result in cache_search_results:
+            # Pop the redis key from the result
+            redis_key = cache_search_result.pop("id")
+            redis_keys.append(redis_key)
+            # Create and process cache hit
             cache_hit = CacheHit(**cache_search_result)
-            cache_hit_dict = {
-                k: v for k, v in cache_hit.to_dict().items() if k in return_fields
-            }
-            cache_hit_dict["key"] = key
+            cache_hit_dict = cache_hit.to_dict()
+            # Filter down to only selected return fields if needed
+            if isinstance(return_fields, list) and len(return_fields) > 0:
+                cache_hit_dict = {
+                    k: v for k, v in cache_hit_dict.items() if k in return_fields
+                }
+            cache_hit_dict[self.redis_key_field_name] = redis_key
             cache_hits.append(cache_hit_dict)
-
-        return cache_hits
+        return redis_keys, cache_hits
 
     def store(
         self,
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 72717398..a722af3f 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -1,3 +1,4 @@
+import asyncio
 from collections import namedtuple
 from time import sleep, time
 
@@ -297,7 +298,7 @@ async def test_async_ttl_refresh(cache_with_ttl, vectorizer):
     await cache_with_ttl.astore(prompt, response, vector=vector)
 
     for _ in range(3):
-        sleep(1)
+        await asyncio.sleep(1)
         check_result = await cache_with_ttl.acheck(vector=vector)
         assert len(check_result) == 1
 
@@ -465,6 +466,9 @@ def test_check_invalid_input(cache):
     with pytest.raises(ValueError):
         cache.check()
 
+    with pytest.raises(TypeError):
+        cache.check(prompt="test", return_fields="bad value")
+
 
 @pytest.mark.asyncio
 async def test_async_check_invalid_input(cache):