Skip to content

Commit

Permalink
Update cache check logic (#216)
Browse files Browse the repository at this point in the history
Refactors the cache check logic into a single helper method. Adjusts the
cache check and processing logic to better handle `return_fields`
configurations.
  • Loading branch information
tylerhutcherson authored Sep 9, 2024
1 parent 973d431 commit df41114
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 37 deletions.
73 changes: 37 additions & 36 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict, List, Optional

from redis import Redis
Expand Down Expand Up @@ -341,8 +342,10 @@ def check(
prompt="What is the capital city of France?"
)
"""
if not (prompt or vector):
if not any([prompt, vector]):
raise ValueError("Either prompt or vector must be specified.")
if return_fields and not isinstance(return_fields, list):
raise TypeError("Return fields must be a list of values.")

# overrides
distance_threshold = distance_threshold or self._distance_threshold
Expand All @@ -359,25 +362,14 @@ def check(
filter_expression=filter_expression,
)

cache_hits: List[Dict[Any, str]] = []

# Search the cache!
cache_search_results = self._index.query(query)

for cache_search_result in cache_search_results:
redis_key = cache_search_result.pop("id")
self._refresh_ttl(redis_key)

# Create and process cache hit
cache_hit = CacheHit(**cache_search_result)
cache_hit_dict = cache_hit.to_dict()
# Filter down to only selected return fields if needed
if isinstance(return_fields, list) and len(return_fields) > 0:
cache_hit_dict = {
k: v for k, v in cache_hit_dict.items() if k in return_fields
}
cache_hit_dict[self.redis_key_field_name] = redis_key
cache_hits.append(cache_hit_dict)
redis_keys, cache_hits = self._process_cache_results(
cache_search_results, return_fields # type: ignore
)
# Extend TTL on keys
for key in redis_keys:
self._refresh_ttl(key)

return cache_hits

Expand Down Expand Up @@ -431,19 +423,16 @@ async def acheck(
"""
aindex = await self._get_async_index()

if not (prompt or vector):
if not any([prompt, vector]):
raise ValueError("Either prompt or vector must be specified.")
if return_fields and not isinstance(return_fields, list):
raise TypeError("Return fields must be a list of values.")

# overrides
distance_threshold = distance_threshold or self._distance_threshold
return_fields = return_fields or self.return_fields
vector = vector or await self._avectorize_prompt(prompt)

self._check_vector_dims(vector)

if not isinstance(return_fields, list):
raise TypeError("return_fields must be a list of field names")

query = RangeQuery(
vector=vector,
vector_field_name=self.vector_field_name,
Expand All @@ -454,24 +443,36 @@ async def acheck(
filter_expression=filter_expression,
)

cache_hits: List[Dict[Any, str]] = []

# Search the cache!
cache_search_results = await aindex.query(query)
redis_keys, cache_hits = self._process_cache_results(
cache_search_results, return_fields # type: ignore
)
# Extend TTL on keys
asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys])

for cache_search_result in cache_search_results:
key = cache_search_result["id"]
await self._async_refresh_ttl(key)
return cache_hits

# Create cache hit
def _process_cache_results(
    self, cache_search_results: List[Dict[str, Any]], return_fields: List[str]
):
    """Convert raw vector-index search results into cache-hit dicts.

    Pops the Redis key out of each raw result, validates the remaining
    fields through the ``CacheHit`` model, and optionally filters each
    hit down to the requested return fields.

    Args:
        cache_search_results: Raw result dicts from the index query; each
            must contain an ``"id"`` entry holding the Redis key.
        return_fields: Field names to keep in each hit. An empty list (or
            a non-list value) means all fields are returned.

    Returns:
        Tuple of ``(redis_keys, cache_hits)``: the popped Redis keys (so
        the caller can refresh TTLs) and the processed hit dictionaries.
    """
    redis_keys: List[str] = []
    cache_hits: List[Dict[Any, str]] = []
    for cache_search_result in cache_search_results:
        # Pop the redis key from the result so it is not passed to CacheHit
        redis_key = cache_search_result.pop("id")
        redis_keys.append(redis_key)
        # Validate/coerce the raw result through the CacheHit model
        cache_hit = CacheHit(**cache_search_result)
        cache_hit_dict = cache_hit.to_dict()
        # Filter down to only selected return fields if needed
        if isinstance(return_fields, list) and len(return_fields) > 0:
            cache_hit_dict = {
                k: v for k, v in cache_hit_dict.items() if k in return_fields
            }
        # Re-attach the redis key under the configured field name
        cache_hit_dict[self.redis_key_field_name] = redis_key
        cache_hits.append(cache_hit_dict)
    return redis_keys, cache_hits

def store(
self,
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections import namedtuple
from time import sleep, time

Expand Down Expand Up @@ -297,7 +298,7 @@ async def test_async_ttl_refresh(cache_with_ttl, vectorizer):
await cache_with_ttl.astore(prompt, response, vector=vector)

for _ in range(3):
sleep(1)
await asyncio.sleep(1)
check_result = await cache_with_ttl.acheck(vector=vector)

assert len(check_result) == 1
Expand Down Expand Up @@ -465,6 +466,9 @@ def test_check_invalid_input(cache):
with pytest.raises(ValueError):
cache.check()

with pytest.raises(TypeError):
cache.check(prompt="test", return_fields="bad value")


@pytest.mark.asyncio
async def test_async_check_invalid_input(cache):
Expand Down

0 comments on commit df41114

Please sign in to comment.