diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index 7505352c..c3a1b269 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -31,7 +31,11 @@ def set_ttl(self, ttl: Optional[int] = None): self._ttl = None def clear(self) -> None: - """Clear the LLMCache of all keys in the index.""" + """Clear the cache of all keys in the index.""" + raise NotImplementedError + + async def aclear(self) -> None: + """Async clear the cache of all keys in the index.""" raise NotImplementedError def check( @@ -41,6 +45,17 @@ def check( num_results: int = 1, return_fields: Optional[List[str]] = None, ) -> List[dict]: + """Check the cache based on a prompt or vector.""" + raise NotImplementedError + + async def acheck( + self, + prompt: Optional[str] = None, + vector: Optional[List[float]] = None, + num_results: int = 1, + return_fields: Optional[List[str]] = None, + ) -> List[dict]: + """Async check the cache based on a prompt or vector.""" raise NotImplementedError def store( @@ -50,7 +65,18 @@ def store( vector: Optional[List[float]] = None, metadata: Optional[dict] = {}, ) -> str: - """Stores the specified key-value pair in the cache along with + """Store the specified key-value pair in the cache along with + metadata.""" + raise NotImplementedError + + async def astore( + self, + prompt: str, + response: str, + vector: Optional[List[float]] = None, + metadata: Optional[dict] = {}, + ) -> str: + """Async store the specified key-value pair in the cache along with metadata.""" raise NotImplementedError diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 17856196..d991287c 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -8,7 +8,7 @@ CacheHit, SemanticCacheIndexSchema, ) -from redisvl.index import SearchIndex +from redisvl.index import AsyncSearchIndex, SearchIndex from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims @@ -27,6 +27,9 @@ class SemanticCache(BaseLLMCache): updated_at_field_name: str = "updated_at" metadata_field_name: str = "metadata" + _index: SearchIndex + _aindex: Optional[AsyncSearchIndex] = None + def __init__( self, name: str = "llmcache", @@ -69,6 +72,12 @@ def __init__( """ super().__init__(ttl) + self.redis_kwargs = { + "redis_client": redis_client, + "redis_url": redis_url, + "connection_kwargs": connection_kwargs, + } + # Use the index name as the key prefix by default if "prefix" in kwargs: prefix = kwargs["prefix"] @@ -81,7 +90,8 @@ def __init__( model="sentence-transformers/all-mpnet-base-v2" ) - # Process fields + # Process fields and other settings + self.set_threshold(distance_threshold) self.return_fields = [ self.entry_id_field_name, self.prompt_field_name, @@ -94,7 +104,6 @@ def __init__( # Create semantic cache schema and index schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) schema = self._modify_schema(schema, filterable_fields) - self._index = SearchIndex(schema=schema) # Handle redis connection @@ -114,13 +123,19 @@ def __init__( "If you wish to overwrite the index schema, set overwrite=True during initialization." 
) - # Initialize other components - self._set_vectorizer(vectorizer) - self.set_threshold(distance_threshold) - - # Create the index + # Create the search index self._index.create(overwrite=overwrite, drop=False) + # Initialize and validate vectorizer + if not isinstance(vectorizer, BaseVectorizer): + raise TypeError("Must provide a valid redisvl.vectorizer class.") + + validate_vector_dims( + vectorizer.dims, + self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + ) + self._vectorizer = vectorizer + def _modify_schema( self, schema: SemanticCacheIndexSchema, @@ -145,6 +160,21 @@ def _modify_schema( return schema + async def _get_async_index(self) -> AsyncSearchIndex: + """Lazily construct the async search index class.""" + if not self._aindex: + # Construct async index if necessary + self._aindex = AsyncSearchIndex(schema=self._index.schema) + # Connect Redis async client + redis_client = self.redis_kwargs["redis_client"] + redis_url = self.redis_kwargs["redis_url"] + connection_kwargs = self.redis_kwargs["connection_kwargs"] + if redis_client is not None: + await self._aindex.set_client(redis_client) + elif redis_url: + await self._aindex.connect(redis_url, **connection_kwargs) # type: ignore + return self._aindex + @property def index(self) -> SearchIndex: """The underlying SearchIndex for the cache. @@ -179,36 +209,25 @@ def set_threshold(self, distance_threshold: float) -> None: ) self._distance_threshold = float(distance_threshold) - def _set_vectorizer(self, vectorizer: BaseVectorizer) -> None: - """Sets the vectorizer for the LLM cache. - - Must be a valid subclass of BaseVectorizer and have equivalent - dimensions to the vector field defined in the schema. - - Args: - vectorizer (BaseVectorizer): The RedisVL vectorizer to use for - vectorizing cache entries. - - Raises: - TypeError: If the vectorizer is not a valid type. - ValueError: If the vector dimensions are mismatched. - """ - if not isinstance(vectorizer, BaseVectorizer): - raise TypeError("Must provide a valid redisvl.vectorizer class.") - - schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore - validate_vector_dims(vectorizer.dims, schema_vector_dims) - self._vectorizer = vectorizer - def clear(self) -> None: """Clear the cache of all keys while preserving the index.""" self._index.clear() + async def aclear(self) -> None: + """""" + aindex = await self._get_async_index() + await aindex.clear() + def delete(self) -> None: """Clear the semantic cache of all keys and remove the underlying search index.""" self._index.delete(drop=True) + async def adelete(self) -> None: + """""" + aindex = await self._get_async_index() + await aindex.delete(drop=True) + def drop( self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None ) -> None: @@ -224,11 +243,34 @@ def drop( if keys is not None: self._index.drop_keys(keys) + async def adrop( + self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None + ) -> None: + """Async expire specific entries from the cache by id or specific + Redis key. + + Args: + ids (Optional[str]): The document ID or IDs to remove from the cache. + keys (Optional[str]): The Redis keys to remove from the cache. 
+ """ + aindex = await self._get_async_index() + + if ids is not None: + await aindex.drop_keys([self._index.key(id) for id in ids]) + if keys is not None: + await aindex.drop_keys(keys) + def _refresh_ttl(self, key: str) -> None: """Refresh the time-to-live for the specified key.""" if self._ttl: self._index.client.expire(key, self._ttl) # type: ignore + async def _async_refresh_ttl(self, key: str) -> None: + """Async refresh the time-to-live for the specified key.""" + aindex = await self._get_async_index() + if self._ttl: + await aindex.client.expire(key, self._ttl) # type: ignore + def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the configured vectorizer.""" @@ -237,6 +279,14 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: return self._vectorizer.embed(prompt) + async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: + """Converts a text prompt to its vector representation using the + configured vectorizer.""" + if not isinstance(prompt, str): + raise TypeError("Prompt must be a string.") + + return await self._vectorizer.aembed(prompt) + def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" @@ -333,6 +383,98 @@ def check( return cache_hits + async def acheck( + self, + prompt: Optional[str] = None, + vector: Optional[List[float]] = None, + num_results: int = 1, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Async check the semantic cache for results similar to the specified prompt + or vector. + + This method searches the cache using vector similarity with + either a raw text prompt (converted to a vector) or a provided vector as + input. It checks for semantically similar prompts and fetches the cached + LLM responses. + + Args: + prompt (Optional[str], optional): The text prompt to search for in + the cache. + vector (Optional[List[float]], optional): The vector representation + of the prompt to search for in the cache. + num_results (int, optional): The number of cached results to return. + Defaults to 1. + return_fields (Optional[List[str]], optional): The fields to include + in each returned result. If None, defaults to all available + fields in the cached entry. + filter_expression (Optional[FilterExpression]) : Optional filter expression + that can be used to filter cache results. Defaults to None and + the full cache will be searched. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. + + Returns: + List[Dict[str, Any]]: A list of dicts containing the requested + return fields for each similar cached response. + + Raises: + ValueError: If neither a `prompt` nor a `vector` is specified. + ValueError: if 'vector' has incorrect dimensions. + TypeError: If `return_fields` is not a list when provided. + + .. code-block:: python + + response = await cache.acheck( + prompt="What is the captial city of France?" 
+ ) + """ + aindex = await self._get_async_index() + + if not (prompt or vector): + raise ValueError("Either prompt or vector must be specified.") + + # overrides + distance_threshold = distance_threshold or self._distance_threshold + return_fields = return_fields or self.return_fields + vector = vector or await self._avectorize_prompt(prompt) + + self._check_vector_dims(vector) + + if not isinstance(return_fields, list): + raise TypeError("return_fields must be a list of field names") + + query = RangeQuery( + vector=vector, + vector_field_name=self.vector_field_name, + return_fields=self.return_fields, + distance_threshold=distance_threshold, + num_results=num_results, + return_score=True, + filter_expression=filter_expression, + ) + + cache_hits: List[Dict[Any, str]] = [] + + # Search the cache! + cache_search_results = await aindex.query(query) + + for cache_search_result in cache_search_results: + key = cache_search_result["id"] + await self._async_refresh_ttl(key) + + # Create cache hit + cache_hit = CacheHit(**cache_search_result) + cache_hit_dict = { + k: v for k, v in cache_hit.to_dict().items() if k in return_fields + } + cache_hit_dict["key"] = key + cache_hits.append(cache_hit_dict) + + return cache_hits + def store( self, prompt: str, @@ -392,6 +534,67 @@ def store( ) return keys[0] + async def astore( + self, + prompt: str, + response: str, + vector: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> str: + """Async stores the specified key-value pair in the cache along with metadata. + + Args: + prompt (str): The user prompt to cache. + response (str): The LLM response to cache. + vector (Optional[List[float]], optional): The prompt vector to + cache. Defaults to None, and the prompt vector is generated on + demand. + metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache + alongside the prompt and response. Defaults to None. + filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. + Defaults to None. + + Returns: + str: The Redis key for the entries added to the semantic cache. + + Raises: + ValueError: If neither prompt nor vector is specified. + ValueError: if vector has incorrect dimensions. + TypeError: If provided metadata is not a dictionary. + + .. code-block:: python + + key = await cache.astore( + prompt="What is the captial city of France?", + response="Paris", + metadata={"city": "Paris", "country": "France"} + ) + """ + aindex = await self._get_async_index() + + # Vectorize prompt if necessary and create cache payload + vector = vector or self._vectorize_prompt(prompt) + + self._check_vector_dims(vector) + + # Build cache entry for the cache + cache_entry = CacheEntry( + prompt=prompt, + response=response, + prompt_vector=vector, + metadata=metadata, + filters=filters, + ) + + # Load cache entry with TTL + keys = await aindex.load( + data=[cache_entry.to_dict()], + ttl=self._ttl, + id_field=self.entry_id_field_name, + ) + return keys[0] + def update(self, key: str, **kwargs) -> None: """Update specific fields within an existing cache entry. If no fields are passed, then only the document TTL is refreshed. @@ -431,3 +634,48 @@ def update(self, key: str, **kwargs) -> None: self._index.client.hset(key, mapping=kwargs) # type: ignore self._refresh_ttl(key) + + async def aupdate(self, key: str, **kwargs) -> None: + """Async update specific fields within an existing cache entry. 
If no fields + are passed, then only the document TTL is refreshed. + + Args: + key (str): the key of the document to update using kwargs. + + Raises: + ValueError if an incorrect mapping is provided as a kwarg. + TypeError if metadata is provided and not of type dict. + + .. code-block:: python + + key = await cache.astore('this is a prompt', 'this is a response') + await cache.aupdate( + key, + metadata={"hit_count": 1, "model_name": "Llama-2-7b"} + ) + """ + aindex = await self._get_async_index() + + if kwargs: + for k, v in kwargs.items(): + + # Make sure the item is in the index schema + if k not in set( + self._index.schema.field_names + [self.metadata_field_name] + ): + raise ValueError(f"{k} is not a valid field within the cache entry") + + # Check for metadata and deserialize + if k == self.metadata_field_name: + if isinstance(v, dict): + kwargs[k] = serialize(v) + else: + raise TypeError( + "If specified, cached metadata must be a dictionary." + ) + + kwargs.update({self.updated_at_field_name: current_timestamp()}) + + await aindex.load(data=[kwargs], keys=[key]) + + await self._async_refresh_ttl(key) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f5e6b4a6..e0e22366 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,4 +1,7 @@ +import asyncio +import atexit import json +import threading from functools import wraps from typing import ( TYPE_CHECKING, @@ -421,7 +424,6 @@ def set_client(self, redis_client: redis.Redis, **kwargs): raise TypeError("Invalid Redis client instance") self._redis_client = redis_client - return self def create(self, overwrite: bool = False, drop: bool = False) -> None: @@ -813,6 +815,31 @@ def __init__( "Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex" ) + atexit.register(self._cleanup_connection) + + def _cleanup_connection(self): + if self._redis_client: + + def run_in_thread(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._redis_client.aclose()) + loop.close() + except RuntimeError: + pass + + # Run cleanup in a background thread to avoid event loop issues + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join() + + self._redis_client = None + + def disconnect(self): + """Disconnect and cleanup the underlying async redis connection.""" + self._cleanup_connection() + @classmethod async def from_existing( cls, @@ -902,10 +929,13 @@ async def set_client(self, redis_client: aredis.Redis): await index.set_client(client) """ - if not isinstance(redis_client, aredis.Redis): - raise TypeError("Invalid Redis client instance") - - self._redis_client = redis_client + if isinstance(redis_client, redis.Redis): + print("Setting client and converting from async", flush=True) + self._redis_client = RedisConnectionFactory.sync_to_async_redis( + redis_client + ) + else: + self._redis_client = redis_client return self diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 9ccc87c8..21095cde 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,8 +1,12 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from redis import Redis +from redis.asyncio import Connection as AsyncConnection +from redis.asyncio import ConnectionPool as AsyncConnectionPool from redis.asyncio import Redis as AsyncRedis +from redis.asyncio import SSLConnection as AsyncSSLConnection +from redis.connection import AbstractConnection, 
SSLConnection from redis.exceptions import ResponseError from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES @@ -226,6 +230,22 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi # fallback to env var REDIS_URL return AsyncRedis.from_url(get_address_from_env(), **kwargs) + @staticmethod + def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: + # pick the right connection class + connection_class: Type[AbstractConnection] = ( + AsyncSSLConnection + if redis_client.connection_pool.connection_class == SSLConnection + else AsyncConnection + ) + # make async client + return AsyncRedis.from_pool( # type: ignore + AsyncConnectionPool( + connection_class=connection_class, + **redis_client.connection_pool.connection_kwargs, + ) + ) + @staticmethod def get_modules(client: Redis) -> Dict[str, Any]: return unpack_redis_modules(convert_bytes(client.module_list())) diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 3ea2dccd..858e94b1 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -53,7 +53,6 @@ def embed( ) -> List[float]: raise NotImplementedError - @abstractmethod async def aembed_many( self, texts: List[str], @@ -62,9 +61,9 @@ async def aembed_many( as_buffer: bool = False, **kwargs, ) -> List[List[float]]: - raise NotImplementedError + # Fallback to standard embedding call if no async support + return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs) - @abstractmethod async def aembed( self, text: str, @@ -72,7 +71,8 @@ async def aembed( as_buffer: bool = False, **kwargs, ) -> List[float]: - raise NotImplementedError + # Fallback to standard embedding call if no async support + return self.embed(text, preprocess, as_buffer, **kwargs) def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None): for pos in range(0, len(seq), size): diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 47275d40..462783a1 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -235,25 +235,6 @@ def embed_many( ] return embeddings - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 1000, - as_buffer: bool = False, - **kwargs, - ) -> List[List[float]]: - raise NotImplementedError - - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[float]: - raise NotImplementedError - @property def type(self) -> str: return "cohere" diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index d5e255c9..ab983ffe 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -144,25 +144,6 @@ def embed_many( ) return embeddings - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 1000, - as_buffer: bool = False, - **kwargs, - ) -> List[List[float]]: - raise NotImplementedError - - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[float]: - raise NotImplementedError - @property def type(self) -> str: return "hf" diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index b7248003..2ab9b83b 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ 
b/redisvl/utils/vectorize/text/vertexai.py @@ -194,25 +194,6 @@ def embed( result = self._client.get_embeddings([text]) return self._process_embedding(result[0].values, as_buffer) - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 1000, - as_buffer: bool = False, - **kwargs, - ) -> List[List[float]]: - raise NotImplementedError - - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[float]: - raise NotImplementedError - @property def type(self) -> str: return "vertexai" diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index e7ba1c3b..a5c937b6 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -152,9 +152,7 @@ async def test_search_index_client(async_client, index_schema): async def test_search_index_set_client(async_client, client, async_index): await async_index.set_client(async_client) assert async_index.client == async_client - # should not be able to set the sync client here - with pytest.raises(TypeError): - await async_index.set_client(client) + await async_index.set_client(client) async_index.disconnect() assert async_index.client == None diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 2263b745..cbfa3e9c 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -6,7 +6,7 @@ from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache -from redisvl.index.index import SearchIndex +from redisvl.index.index import AsyncSearchIndex, SearchIndex from redisvl.query.filter import Num, Tag, Text from redisvl.utils.vectorize import HFTextVectorizer @@ -89,6 +89,33 @@ def test_reset_ttl(cache): assert cache.ttl is None +def test_get_index(cache): + assert isinstance(cache.index, SearchIndex) + + +@pytest.mark.asyncio +async def test_get_async_index(cache): + aindex = await cache._get_async_index() + assert isinstance(aindex, AsyncSearchIndex) + + +@pytest.mark.asyncio +async def test_get_async_index_from_provided_client(cache_with_redis_client): + aindex = await cache_with_redis_client._get_async_index() + assert isinstance(aindex, AsyncSearchIndex) + + +def test_delete(cache_no_cleanup): + cache_no_cleanup.delete() + assert not cache_no_cleanup.index.exists() + + +@pytest.mark.asyncio +async def test_async_delete(cache_no_cleanup): + await cache_no_cleanup.adelete() + assert not cache_no_cleanup.index.exists() + + def test_store_and_check(cache, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -103,6 +130,21 @@ def test_store_and_check(cache, vectorizer): assert "metadata" not in check_result[0] +@pytest.mark.asyncio +async def test_async_store_and_check(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector, distance_threshold=0.4) + + assert len(check_result) == 1 + print(check_result, flush=True) + assert response == check_result[0]["response"] + assert "metadata" not in check_result[0] + + def test_return_fields(cache, vectorizer): prompt = "This is a test prompt." response = "This is a test response." 
@@ -140,6 +182,44 @@ def test_return_fields(cache, vectorizer): assert set(check_result[0].keys()) == set(fields) +@pytest.mark.asyncio +async def test_async_return_fields(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + + # check default return fields + check_result = await cache.acheck(vector=vector) + assert set(check_result[0].keys()) == { + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + "inserted_at", + "updated_at", + } + + # check specific return fields + fields = [ + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + ] + check_result = await cache.acheck(vector=vector, return_fields=fields) + assert set(check_result[0].keys()) == set(fields) + + # check only some return fields + fields = ["inserted_at", "updated_at"] + check_result = await cache.acheck(vector=vector, return_fields=fields) + fields.append("key") + assert set(check_result[0].keys()) == set(fields) + + # Test clearing the cache def test_clear(cache, vectorizer): prompt = "This is a test prompt." @@ -153,6 +233,19 @@ def test_clear(cache, vectorizer): assert len(check_result) == 0 +@pytest.mark.asyncio +async def test_async_clear(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + await cache.aclear() + check_result = await cache.acheck(vector=vector) + + assert len(check_result) == 0 + + # Test TTL functionality def test_ttl_expiration(cache_with_ttl, vectorizer): prompt = "This is a test prompt." @@ -166,6 +259,19 @@ def test_ttl_expiration(cache_with_ttl, vectorizer): assert len(check_result) == 0 +@pytest.mark.asyncio +async def test_async_ttl_expiration(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache_with_ttl.astore(prompt, response, vector=vector) + sleep(3) + + check_result = await cache_with_ttl.acheck(vector=vector) + assert len(check_result) == 0 + + def test_ttl_refresh(cache_with_ttl, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -180,6 +286,21 @@ def test_ttl_refresh(cache_with_ttl, vectorizer): assert len(check_result) == 1 +@pytest.mark.asyncio +async def test_async_ttl_refresh(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache_with_ttl.astore(prompt, response, vector=vector) + + for _ in range(3): + sleep(1) + check_result = await cache_with_ttl.acheck(vector=vector) + + assert len(check_result) == 1 + + # Test manual expiration of single document def test_drop_document(cache, vectorizer): prompt = "This is a test prompt." @@ -194,6 +315,20 @@ def test_drop_document(cache, vectorizer): assert len(recheck_result) == 0 +@pytest.mark.asyncio +async def test_async_drop_document(cache, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." 
+ vector = vectorizer.embed(prompt) + + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector) + + await cache.adrop(ids=[check_result[0]["entry_id"]]) + recheck_result = await cache.acheck(vector=vector) + assert len(recheck_result) == 0 + + # Test manual expiration of multiple documents def test_drop_documents(cache, vectorizer): prompts = [ @@ -219,6 +354,31 @@ def test_drop_documents(cache, vectorizer): assert len(recheck_result) == 1 +@pytest.mark.asyncio +async def test_async_drop_documents(cache, vectorizer): + prompts = [ + "This is a test prompt.", + "This is also test prompt.", + "This is another test prompt.", + ] + responses = [ + "This is a test response.", + "This is also test response.", + "This is a another test response.", + ] + for prompt, response in zip(prompts, responses): + vector = vectorizer.embed(prompt) + await cache.astore(prompt, response, vector=vector) + + check_result = await cache.acheck(vector=vector, num_results=3) + print(check_result, flush=True) + ids = [r["entry_id"] for r in check_result[0:2]] # drop first 2 entries + await cache.adrop(ids=ids) + + recheck_result = await cache.acheck(vector=vector, num_results=3) + assert len(recheck_result) == 1 + + # Test updating document fields def test_updating_document(cache): prompt = "This is a test prompt." @@ -240,6 +400,27 @@ def test_updating_document(cache): assert updated_result[0]["updated_at"] > check_result[0]["updated_at"] +@pytest.mark.asyncio +async def test_async_updating_document(cache): + prompt = "This is a test prompt." + response = "This is a test response." + await cache.astore(prompt=prompt, response=response) + + check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"]) + key = check_result[0]["key"] + + sleep(1) + + metadata = {"foo": "bar"} + await cache.aupdate(key=key, metadata=metadata) + + updated_result = await cache.acheck( + prompt=prompt, return_fields=["updated_at", "metadata"] + ) + assert updated_result[0]["metadata"] == metadata + assert updated_result[0]["updated_at"] > check_result[0]["updated_at"] + + def test_ttl_expiration_after_update(cache_with_ttl, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -255,6 +436,22 @@ def test_ttl_expiration_after_update(cache_with_ttl, vectorizer): assert len(check_result) == 0 +@pytest.mark.asyncio +async def test_async_ttl_expiration_after_update(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." 
+ vector = vectorizer.embed(prompt) + cache_with_ttl.set_ttl(4) + + assert cache_with_ttl.ttl == 4 + + await cache_with_ttl.astore(prompt, response, vector=vector) + sleep(5) + + check_result = await cache_with_ttl.acheck(vector=vector) + assert len(check_result) == 0 + + # Test check behavior with no match def test_check_no_match(cache, vectorizer): vector = vectorizer.embed("Some random sentence.") @@ -270,6 +467,15 @@ def test_check_invalid_input(cache): cache.check(prompt="test", return_fields="bad value") +@pytest.mark.asyncio +async def test_async_check_invalid_input(cache): + with pytest.raises(ValueError): + await cache.acheck() + + with pytest.raises(TypeError): + await cache.acheck(prompt="test", return_fields="bad value") + + def test_bad_connection_info(vectorizer): with pytest.raises(ConnectionError): SemanticCache( @@ -357,10 +563,6 @@ def test_multiple_items(cache, vectorizer): assert "metadata" not in check_result[0] -def test_get_index(cache): - assert isinstance(cache.index, SearchIndex) - - def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer): prompt = "This is a test prompt." response = "This is a test response." @@ -375,9 +577,21 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize assert "metadata" not in check_result[0] -def test_delete(cache_no_cleanup): - cache_no_cleanup.delete() - assert not cache_no_cleanup.index.exists() +@pytest.mark.asyncio +async def test_async_store_and_check_with_provided_client( + cache_with_redis_client, vectorizer +): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + await cache_with_redis_client.astore(prompt, response, vector=vector) + check_result = await cache_with_redis_client.acheck(vector=vector) + + assert len(check_result) == 1 + print(check_result, flush=True) + assert response == check_result[0]["response"] + assert "metadata" not in check_result[0] def test_vector_size(cache, vectorizer): diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned
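
Example usage of the new async SemanticCache surface added by this patch (an illustrative sketch, not part of the diff: the redis_url, prompts, responses, and metadata values are placeholders; the cache falls back to the default HuggingFace vectorizer):

import asyncio

from redisvl.extensions.llmcache import SemanticCache


async def main():
    # Cache backed by a local Redis instance (URL is a placeholder)
    cache = SemanticCache(
        name="llmcache",
        redis_url="redis://localhost:6379",
        distance_threshold=0.1,
        ttl=60,
    )

    # Store a prompt/response pair without blocking the event loop
    key = await cache.astore(
        prompt="What is the capital city of France?",
        response="Paris",
        metadata={"country": "France"},
    )

    # Semantically check the cache for a similar prompt
    hits = await cache.acheck(prompt="What's the capital of France?", num_results=1)
    if hits:
        print(hits[0]["response"])

    # Update metadata on an existing entry (also refreshes its TTL)
    await cache.aupdate(key, metadata={"country": "France", "hit_count": 1})

    # Expire a specific entry, then clear or delete the whole cache
    await cache.adrop(keys=[key])
    await cache.aclear()
    await cache.adelete()


asyncio.run(main())

Each async method lazily builds the AsyncSearchIndex on first use via _get_async_index(), reusing the connection settings captured in redis_kwargs at construction time.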
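With the base-class fallbacks above, aembed and aembed_many can be awaited on any vectorizer, including providers with no native async client; a minimal sketch (the model name is the library default used elsewhere in this patch):

import asyncio

from redisvl.utils.vectorize import HFTextVectorizer


async def embed_examples():
    # HuggingFace has no async API; aembed/aembed_many now delegate to embed/embed_many
    vectorizer = HFTextVectorizer(model="sentence-transformers/all-mpnet-base-v2")
    single = await vectorizer.aembed("hello world")
    batch = await vectorizer.aembed_many(["hello", "world"], batch_size=2)
    print(len(single), len(batch))


asyncio.run(embed_examples())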
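AsyncSearchIndex.set_client now also accepts a synchronous redis.Redis client, converting it through RedisConnectionFactory.sync_to_async_redis instead of raising TypeError. A rough sketch of that flow, assuming a minimal placeholder schema:

import asyncio

import redis

from redisvl.index import AsyncSearchIndex
from redisvl.schema import IndexSchema


async def main():
    # Placeholder schema; any valid IndexSchema works here
    schema = IndexSchema.from_dict(
        {
            "index": {"name": "docs", "prefix": "docs"},
            "fields": [{"name": "content", "type": "text"}],
        }
    )
    index = AsyncSearchIndex(schema=schema)

    # A plain sync client is converted to an async client under the hood
    sync_client = redis.Redis.from_url("redis://localhost:6379")
    await index.set_client(sync_client)
    await index.create(overwrite=True)

    # New sync helper that closes the underlying async connection
    index.disconnect()


asyncio.run(main())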