fix formatting and mypy
tylerhutcherson committed Aug 1, 2024
1 parent 51b6af3 commit 704038b
Showing 4 changed files with 78 additions and 34 deletions.
8 changes: 5 additions & 3 deletions redisvl/extensions/llmcache/schema.py
@@ -1,12 +1,14 @@
from typing import Any, Dict, List, Optional

from pydantic.v1 import BaseModel, Field, root_validator, validator

from redisvl.redis.utils import array_to_buffer, hashify
-from redisvl.utils.utils import current_timestamp, deserialize, serialize
+from redisvl.schema import IndexSchema
+from redisvl.utils.utils import current_timestamp, deserialize, serialize


class CacheEntry(BaseModel):
-    entry_id: str
+    entry_id: Optional[str] = Field(default=None)
prompt: str
response: str
prompt_vector: List[float]
@@ -103,4 +105,4 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
},
},
],
-)
+)
25 changes: 19 additions & 6 deletions redisvl/extensions/llmcache/semantic.py
@@ -3,7 +3,11 @@
from redis import Redis

from redisvl.extensions.llmcache.base import BaseLLMCache
-from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit, SemanticCacheIndexSchema
+from redisvl.extensions.llmcache.schema import (
+    CacheEntry,
+    CacheHit,
+    SemanticCacheIndexSchema,
+)
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.query.filter import FilterExpression, Tag
@@ -92,8 +96,13 @@ def __init__(

if filterable_fields is not None:
for filter_field in filterable_fields:
-                if filter_field["name"] in self.return_fields or filter_field["name"] =="key":
-                    raise ValueError(f'{filter_field["name"]} is a reserved field name for the semantic cache schema')
+                if (
+                    filter_field["name"] in self.return_fields
+                    or filter_field["name"] == "key"
+                ):
+                    raise ValueError(
+                        f'{filter_field["name"]} is a reserved field name for the semantic cache schema'
+                    )
schema.add_field(filter_field)
# Add to return fields too
self.return_fields.append(filter_field["name"])
@@ -285,7 +294,9 @@ def check(

# Create cache hit
cache_hit = CacheHit(**cache_search_result)
-            cache_hit_dict = {k: v for k, v in cache_hit.to_dict().items() if k in return_fields}
+            cache_hit_dict = {
+                k: v for k, v in cache_hit.to_dict().items() if k in return_fields
+            }
cache_hit_dict["key"] = key
cache_hits.append(cache_hit_dict)

@@ -370,7 +381,9 @@ def update(self, key: str, **kwargs) -> None:
for k, v in kwargs.items():

# Make sure the item is in the index schema
-                if k not in set(self._index.schema.field_names + [self.metadata_field_name]):
+                if k not in set(
+                    self._index.schema.field_names + [self.metadata_field_name]
+                ):
raise ValueError(f"{k} is not a valid field within the cache entry")

# Check for metadata and deserialize
@@ -384,6 +397,6 @@ def update(self, key: str, **kwargs) -> None:

kwargs.update({self.updated_at_field_name: current_timestamp()})

-        self._index.client.hset(key, mapping=kwargs) # type: ignore
+        self._index.client.hset(key, mapping=kwargs)  # type: ignore

self._refresh_ttl(key)
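For reference, a minimal sketch (not part of the commit) of how the filterable-fields, check, and update code paths touched above are typically exercised. It assumes the SemanticCache API as it appears in this diff; the vectorizer defaults, field values, and Redis URL are illustrative assumptions.

```python
from redisvl.extensions.llmcache import SemanticCache
from redisvl.query.filter import Tag

# Illustrative setup: default vectorizer and a local Redis instance are assumed.
cache = SemanticCache(
    name="llmcache",
    distance_threshold=0.2,
    filterable_fields=[{"name": "label", "type": "tag"}],
    redis_url="redis://localhost:6379",
)

# Store an entry tagged via the custom filterable field.
key = cache.store(
    prompt="What is the capital of France?",
    response="Paris",
    filters={"label": "geography"},
)

# Check the cache, scoping the lookup to one tag value; reserved names
# ("key", "prompt", "metadata", ...) are rejected as filterable fields above.
hits = cache.check(
    prompt="What is the capital of France?",
    filter_expression=Tag("label") == "geography",
    num_results=1,
)

# Update fields on an existing entry; per the diff, this re-HSETs the changed
# fields and refreshes updated_at and the key's TTL.
cache.update(key, metadata={"reviewed": True})
```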
48 changes: 34 additions & 14 deletions tests/integration/test_llmcache.py
@@ -1,8 +1,8 @@
from collections import namedtuple
from time import sleep, time
-from pydantic.v1 import ValidationError
-import pytest

+import pytest
+from pydantic.v1 import ValidationError
from redis.exceptions import ConnectionError

from redisvl.extensions.llmcache import SemanticCache
@@ -24,13 +24,14 @@ def cache(vectorizer, redis_url):
yield cache_instance
cache_instance._index.delete(True) # Clean up index


@pytest.fixture
def cache_with_filters(vectorizer, redis_url):
cache_instance = SemanticCache(
vectorizer=vectorizer,
distance_threshold=0.2,
filterable_fields=[{"name": "label", "type": "tag"}],
-        redis_url=redis_url
+        redis_url=redis_url,
)
yield cache_instance
cache_instance._index.delete(True) # Clean up index
@@ -411,13 +412,17 @@ def test_cache_filtering(cache_with_filters):
cache_with_filters.store(prompt, response, filters={"label": tags[i]})

# test we can specify one specific tag
-    results = cache_with_filters.check("test prompt 1", filter_expression=filter_1, num_results=5)
+    results = cache_with_filters.check(
+        "test prompt 1", filter_expression=filter_1, num_results=5
+    )
assert len(results) == 1
assert results[0]["prompt"] == "test prompt 0"

# test we can pass a list of tags
combined_filter = filter_1 | filter_2 | filter_3
-    results = cache_with_filters.check("test prompt 1", filter_expression=combined_filter, num_results=5)
+    results = cache_with_filters.check(
+        "test prompt 1", filter_expression=combined_filter, num_results=5
+    )
assert len(results) == 3

# test that default tag param searches full cache
@@ -426,7 +431,9 @@

    # test no results are returned if we pass a nonexistent tag
bad_filter = Tag("label") == "bad tag"
-    results = cache_with_filters.check("test prompt 1", filter_expression=bad_filter, num_results=5)
+    results = cache_with_filters.check(
+        "test prompt 1", filter_expression=bad_filter, num_results=5
+    )
assert len(results) == 0


@@ -436,26 +443,35 @@ def test_cache_bad_filters(vectorizer, redis_url):
vectorizer=vectorizer,
distance_threshold=0.2,
# invalid field type
filterable_fields=[{"name": "label", "type": "tag"}, {"name": "test", "type": "nothing"}],
redis_url=redis_url
filterable_fields=[
{"name": "label", "type": "tag"},
{"name": "test", "type": "nothing"},
],
redis_url=redis_url,
)

with pytest.raises(ValueError):
cache_instance = SemanticCache(
vectorizer=vectorizer,
distance_threshold=0.2,
# duplicate field type
filterable_fields=[{"name": "label", "type": "tag"}, {"name": "label", "type": "tag"}],
redis_url=redis_url
filterable_fields=[
{"name": "label", "type": "tag"},
{"name": "label", "type": "tag"},
],
redis_url=redis_url,
)

with pytest.raises(ValueError):
cache_instance = SemanticCache(
vectorizer=vectorizer,
distance_threshold=0.2,
# reserved field name
filterable_fields=[{"name": "label", "type": "tag"}, {"name": "metadata", "type": "tag"}],
redis_url=redis_url
filterable_fields=[
{"name": "label", "type": "tag"},
{"name": "metadata", "type": "tag"},
],
redis_url=redis_url,
)


@@ -468,12 +484,16 @@ def test_complex_filters(cache_with_filters):

# test we can do range filters on inserted_at and updated_at fields
range_filter = Num("inserted_at") < current_timestamp
-    results = cache_with_filters.check("prompt 1", filter_expression=range_filter, num_results=5)
+    results = cache_with_filters.check(
+        "prompt 1", filter_expression=range_filter, num_results=5
+    )
assert len(results) == 2

# test we can combine range filters and text filters
prompt_filter = Text("prompt") % "*pt 1"
combined_filter = prompt_filter & range_filter

-    results = cache_with_filters.check("prompt 1", filter_expression=combined_filter, num_results=5)
+    results = cache_with_filters.check(
+        "prompt 1", filter_expression=combined_filter, num_results=5
+    )
assert len(results) == 1
31 changes: 20 additions & 11 deletions tests/unit/test_llmcache_schema.py
@@ -1,47 +1,51 @@
-import pytest
import json

+import pytest
from pydantic.v1 import ValidationError

from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit
-from redisvl.redis.utils import hashify, array_to_buffer
+from redisvl.redis.utils import array_to_buffer, hashify


def test_valid_cache_entry_creation():
entry = CacheEntry(
prompt="What is AI?",
response="AI is artificial intelligence.",
-        prompt_vector=[0.1, 0.2, 0.3]
+        prompt_vector=[0.1, 0.2, 0.3],
)
assert entry.entry_id == hashify("What is AI?")
assert entry.prompt == "What is AI?"
assert entry.response == "AI is artificial intelligence."
assert entry.prompt_vector == [0.1, 0.2, 0.3]


def test_cache_entry_with_given_entry_id():
entry = CacheEntry(
entry_id="custom_id",
prompt="What is AI?",
response="AI is artificial intelligence.",
-        prompt_vector=[0.1, 0.2, 0.3]
+        prompt_vector=[0.1, 0.2, 0.3],
)
assert entry.entry_id == "custom_id"


def test_cache_entry_with_invalid_metadata():
with pytest.raises(ValidationError):
CacheEntry(
prompt="What is AI?",
response="AI is artificial intelligence.",
prompt_vector=[0.1, 0.2, 0.3],
metadata="invalid_metadata"
metadata="invalid_metadata",
)


def test_cache_entry_to_dict():
entry = CacheEntry(
prompt="What is AI?",
response="AI is artificial intelligence.",
prompt_vector=[0.1, 0.2, 0.3],
metadata={"author": "John"},
filters={"category": "technology"}
filters={"category": "technology"},
)
result = entry.to_dict()
assert result["entry_id"] == hashify("What is AI?")
@@ -50,21 +54,23 @@ def test_cache_entry_to_dict():
assert result["category"] == "technology"
assert "filters" not in result


def test_valid_cache_hit_creation():
hit = CacheHit(
entry_id="entry_1",
prompt="What is AI?",
response="AI is artificial intelligence.",
vector_distance=0.1,
inserted_at=1625819123.123,
-        updated_at=1625819123.123
+        updated_at=1625819123.123,
)
assert hit.entry_id == "entry_1"
assert hit.prompt == "What is AI?"
assert hit.response == "AI is artificial intelligence."
assert hit.vector_distance == 0.1
assert hit.inserted_at == hit.updated_at == 1625819123.123


def test_cache_hit_with_serialized_metadata():
hit = CacheHit(
entry_id="entry_1",
@@ -73,10 +79,11 @@ def test_cache_hit_with_serialized_metadata():
vector_distance=0.1,
inserted_at=1625819123.123,
updated_at=1625819123.123,
-        metadata=json.dumps({"author": "John"})
+        metadata=json.dumps({"author": "John"}),
)
assert hit.metadata == {"author": "John"}


def test_cache_hit_to_dict():
hit = CacheHit(
entry_id="entry_1",
@@ -85,7 +92,7 @@ def test_cache_hit_to_dict():
vector_distance=0.1,
inserted_at=1625819123.123,
updated_at=1625819123.123,
filters={"category": "technology"}
filters={"category": "technology"},
)
result = hit.to_dict()
assert result["entry_id"] == "entry_1"
@@ -95,24 +102,26 @@ def test_cache_hit_to_dict():
assert result["category"] == "technology"
assert "filters" not in result


def test_cache_entry_with_empty_optional_fields():
entry = CacheEntry(
prompt="What is AI?",
response="AI is artificial intelligence.",
-        prompt_vector=[0.1, 0.2, 0.3]
+        prompt_vector=[0.1, 0.2, 0.3],
)
result = entry.to_dict()
assert "metadata" not in result
assert "filters" not in result


def test_cache_hit_with_empty_optional_fields():
hit = CacheHit(
entry_id="entry_1",
prompt="What is AI?",
response="AI is artificial intelligence.",
vector_distance=0.1,
inserted_at=1625819123.123,
-        updated_at=1625819123.123
+        updated_at=1625819123.123,
)
result = hit.to_dict()
assert "metadata" not in result
