From 704038b74c0152d13c8e961ef9dcf8ad76854525 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 1 Aug 2024 00:08:42 -0400 Subject: [PATCH] fix formatting and mypy --- redisvl/extensions/llmcache/schema.py | 8 +++-- redisvl/extensions/llmcache/semantic.py | 25 +++++++++---- tests/integration/test_llmcache.py | 48 +++++++++++++++++-------- tests/unit/test_llmcache_schema.py | 31 ++++++++++------ 4 files changed, 78 insertions(+), 34 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 64ddeb21..60dcdc94 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -1,12 +1,14 @@ from typing import Any, Dict, List, Optional + from pydantic.v1 import BaseModel, Field, root_validator, validator + from redisvl.redis.utils import array_to_buffer, hashify -from redisvl.utils.utils import current_timestamp, deserialize, serialize from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp, deserialize, serialize class CacheEntry(BaseModel): - entry_id: str + entry_id: Optional[str] = Field(default=None) prompt: str response: str prompt_vector: List[float] @@ -103,4 +105,4 @@ def from_params(cls, name: str, prefix: str, vector_dims: int): }, }, ], - ) \ No newline at end of file + ) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index f0b5b66b..b790e6d1 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -3,7 +3,11 @@ from redis import Redis from redisvl.extensions.llmcache.base import BaseLLMCache -from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit, SemanticCacheIndexSchema +from redisvl.extensions.llmcache.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression, Tag @@ -92,8 +96,13 @@ def __init__( if filterable_fields is not None: for filter_field in filterable_fields: - if filter_field["name"] in self.return_fields or filter_field["name"] =="key": - raise ValueError(f'{filter_field["name"]} is a reserved field name for the semantic cache schema') + if ( + filter_field["name"] in self.return_fields + or filter_field["name"] == "key" + ): + raise ValueError( + f'{filter_field["name"]} is a reserved field name for the semantic cache schema' + ) schema.add_field(filter_field) # Add to return fields too self.return_fields.append(filter_field["name"]) @@ -285,7 +294,9 @@ def check( # Create cache hit cache_hit = CacheHit(**cache_search_result) - cache_hit_dict = {k: v for k, v in cache_hit.to_dict().items() if k in return_fields} + cache_hit_dict = { + k: v for k, v in cache_hit.to_dict().items() if k in return_fields + } cache_hit_dict["key"] = key cache_hits.append(cache_hit_dict) @@ -370,7 +381,9 @@ def update(self, key: str, **kwargs) -> None: for k, v in kwargs.items(): # Make sure the item is in the index schema - if k not in set(self._index.schema.field_names + [self.metadata_field_name]): + if k not in set( + self._index.schema.field_names + [self.metadata_field_name] + ): raise ValueError(f"{k} is not a valid field within the cache entry") # Check for metadata and deserialize @@ -384,6 +397,6 @@ def update(self, key: str, **kwargs) -> None: kwargs.update({self.updated_at_field_name: current_timestamp()}) - self._index.client.hset(key, mapping=kwargs) # type: ignore + self._index.client.hset(key, mapping=kwargs) # type: ignore self._refresh_ttl(key) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index ec1c7c15..34c15113 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,8 +1,8 @@ from collections import namedtuple from time import sleep, time -from pydantic.v1 import ValidationError -import pytest +import pytest +from pydantic.v1 import ValidationError from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache @@ -24,13 +24,14 @@ def cache(vectorizer, redis_url): yield cache_instance cache_instance._index.delete(True) # Clean up index + @pytest.fixture def cache_with_filters(vectorizer, redis_url): cache_instance = SemanticCache( vectorizer=vectorizer, distance_threshold=0.2, filterable_fields=[{"name": "label", "type": "tag"}], - redis_url=redis_url + redis_url=redis_url, ) yield cache_instance cache_instance._index.delete(True) # Clean up index @@ -411,13 +412,17 @@ def test_cache_filtering(cache_with_filters): cache_with_filters.store(prompt, response, filters={"label": tags[i]}) # test we can specify one specific tag - results = cache_with_filters.check("test prompt 1", filter_expression=filter_1, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=filter_1, num_results=5 + ) assert len(results) == 1 assert results[0]["prompt"] == "test prompt 0" # test we can pass a list of tags combined_filter = filter_1 | filter_2 | filter_3 - results = cache_with_filters.check("test prompt 1", filter_expression=combined_filter, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=combined_filter, num_results=5 + ) assert len(results) == 3 # test that default tag param searches full cache @@ -426,7 +431,9 @@ def test_cache_filtering(cache_with_filters): # test no results are returned if we pass a nonexistant tag bad_filter = Tag("label") == "bad tag" - results = cache_with_filters.check("test prompt 1", filter_expression=bad_filter, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=bad_filter, num_results=5 + ) assert len(results) == 0 @@ -436,8 +443,11 @@ def test_cache_bad_filters(vectorizer, redis_url): vectorizer=vectorizer, distance_threshold=0.2, # invalid field type - filterable_fields=[{"name": "label", "type": "tag"}, {"name": "test", "type": "nothing"}], - redis_url=redis_url + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "test", "type": "nothing"}, + ], + redis_url=redis_url, ) with pytest.raises(ValueError): @@ -445,8 +455,11 @@ def test_cache_bad_filters(vectorizer, redis_url): vectorizer=vectorizer, distance_threshold=0.2, # duplicate field type - filterable_fields=[{"name": "label", "type": "tag"}, {"name": "label", "type": "tag"}], - redis_url=redis_url + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "label", "type": "tag"}, + ], + redis_url=redis_url, ) with pytest.raises(ValueError): @@ -454,8 +467,11 @@ def test_cache_bad_filters(vectorizer, redis_url): vectorizer=vectorizer, distance_threshold=0.2, # reserved field name - filterable_fields=[{"name": "label", "type": "tag"}, {"name": "metadata", "type": "tag"}], - redis_url=redis_url + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "metadata", "type": "tag"}, + ], + redis_url=redis_url, ) @@ -468,12 +484,16 @@ def test_complex_filters(cache_with_filters): # test we can do range filters on inserted_at and updated_at fields range_filter = Num("inserted_at") < current_timestamp - results = cache_with_filters.check("prompt 1", filter_expression=range_filter, num_results=5) + results = cache_with_filters.check( + "prompt 1", filter_expression=range_filter, num_results=5 + ) assert len(results) == 2 # test we can combine range filters and text filters prompt_filter = Text("prompt") % "*pt 1" combined_filter = prompt_filter & range_filter - results = cache_with_filters.check("prompt 1", filter_expression=combined_filter, num_results=5) + results = cache_with_filters.check( + "prompt 1", filter_expression=combined_filter, num_results=5 + ) assert len(results) == 1 diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index 11b8bb58..e3961e6b 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -1,47 +1,51 @@ -import pytest import json +import pytest from pydantic.v1 import ValidationError + from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit -from redisvl.redis.utils import hashify, array_to_buffer +from redisvl.redis.utils import array_to_buffer, hashify def test_valid_cache_entry_creation(): entry = CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", - prompt_vector=[0.1, 0.2, 0.3] + prompt_vector=[0.1, 0.2, 0.3], ) assert entry.entry_id == hashify("What is AI?") assert entry.prompt == "What is AI?" assert entry.response == "AI is artificial intelligence." assert entry.prompt_vector == [0.1, 0.2, 0.3] + def test_cache_entry_with_given_entry_id(): entry = CacheEntry( entry_id="custom_id", prompt="What is AI?", response="AI is artificial intelligence.", - prompt_vector=[0.1, 0.2, 0.3] + prompt_vector=[0.1, 0.2, 0.3], ) assert entry.entry_id == "custom_id" + def test_cache_entry_with_invalid_metadata(): with pytest.raises(ValidationError): CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - metadata="invalid_metadata" + metadata="invalid_metadata", ) + def test_cache_entry_to_dict(): entry = CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], metadata={"author": "John"}, - filters={"category": "technology"} + filters={"category": "technology"}, ) result = entry.to_dict() assert result["entry_id"] == hashify("What is AI?") @@ -50,6 +54,7 @@ def test_cache_entry_to_dict(): assert result["category"] == "technology" assert "filters" not in result + def test_valid_cache_hit_creation(): hit = CacheHit( entry_id="entry_1", @@ -57,7 +62,7 @@ def test_valid_cache_hit_creation(): response="AI is artificial intelligence.", vector_distance=0.1, inserted_at=1625819123.123, - updated_at=1625819123.123 + updated_at=1625819123.123, ) assert hit.entry_id == "entry_1" assert hit.prompt == "What is AI?" @@ -65,6 +70,7 @@ def test_valid_cache_hit_creation(): assert hit.vector_distance == 0.1 assert hit.inserted_at == hit.updated_at == 1625819123.123 + def test_cache_hit_with_serialized_metadata(): hit = CacheHit( entry_id="entry_1", @@ -73,10 +79,11 @@ def test_cache_hit_with_serialized_metadata(): vector_distance=0.1, inserted_at=1625819123.123, updated_at=1625819123.123, - metadata=json.dumps({"author": "John"}) + metadata=json.dumps({"author": "John"}), ) assert hit.metadata == {"author": "John"} + def test_cache_hit_to_dict(): hit = CacheHit( entry_id="entry_1", @@ -85,7 +92,7 @@ def test_cache_hit_to_dict(): vector_distance=0.1, inserted_at=1625819123.123, updated_at=1625819123.123, - filters={"category": "technology"} + filters={"category": "technology"}, ) result = hit.to_dict() assert result["entry_id"] == "entry_1" @@ -95,16 +102,18 @@ def test_cache_hit_to_dict(): assert result["category"] == "technology" assert "filters" not in result + def test_cache_entry_with_empty_optional_fields(): entry = CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", - prompt_vector=[0.1, 0.2, 0.3] + prompt_vector=[0.1, 0.2, 0.3], ) result = entry.to_dict() assert "metadata" not in result assert "filters" not in result + def test_cache_hit_with_empty_optional_fields(): hit = CacheHit( entry_id="entry_1", @@ -112,7 +121,7 @@ def test_cache_hit_with_empty_optional_fields(): response="AI is artificial intelligence.", vector_distance=0.1, inserted_at=1625819123.123, - updated_at=1625819123.123 + updated_at=1625819123.123, ) result = hit.to_dict() assert "metadata" not in result