Skip to content

Commit

Permalink
formatting and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Aug 30, 2024
1 parent 02fca2d commit 156a14b
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def process_results(
unpack_json = (
(storage_type == StorageType.JSON)
and isinstance(query, FilterQuery)
and not query._return_fields
and not query._return_fields # type: ignore
)

# Process records
Expand Down
2 changes: 1 addition & 1 deletion redisvl/query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
"FilterQuery",
"RangeQuery",
"VectorRangeQuery",
"CountQuery"
"CountQuery",
]
15 changes: 8 additions & 7 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __str__(self) -> str:

def _build_query_string(self) -> str:
"""Build the full Redis query string."""
pass
raise NotImplementedError("Must be implemented by subclasses")

def set_filter(self, filter_expression: Optional[FilterExpression] = None):
"""Set the filter expression for the query.
Expand All @@ -45,12 +45,13 @@ def set_filter(self, filter_expression: Optional[FilterExpression] = None):
elif isinstance(filter_expression, FilterExpression):
self._filter_expression = filter_expression
else:
raise TypeError("filter_expression must be of type FilterExpression or None")
raise TypeError(
"filter_expression must be of type FilterExpression or None"
)

# Reset the query string
self._query_string = self._build_query_string()


@property
def filter(self) -> FilterExpression:
"""The filter expression for the query."""
Expand Down Expand Up @@ -167,16 +168,16 @@ def _build_query_string(self) -> str:
return str(self._filter_expression)


class BaseVectorQuery(BaseQuery):
DTYPES: Dict[str, np.dtype] = {
class BaseVectorQuery:
DTYPES: Dict[str, Any] = {
"float32": np.float32,
"float64": np.float64,
}
DISTANCE_ID: str = "vector_distance"
VECTOR_PARAM: str = "vector"


class VectorQuery(BaseVectorQuery):
class VectorQuery(BaseVectorQuery, BaseQuery):
def __init__(
self,
vector: Union[List[float], bytes],
Expand Down Expand Up @@ -268,7 +269,7 @@ def params(self) -> Dict[str, Any]:
return {self.VECTOR_PARAM: vector}


class VectorRangeQuery(BaseVectorQuery):
class VectorRangeQuery(BaseVectorQuery, BaseQuery):
DISTANCE_THRESHOLD_PARAM: str = "distance_threshold"

def __init__(
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def vectorizer():
@pytest.fixture
def cache(vectorizer, redis_url):
cache_instance = SemanticCache(
vectorizer=vectorizer, distance_threshold=0.2, redis_url="redis://localhost:6379"
vectorizer=vectorizer,
distance_threshold=0.2,
redis_url="redis://localhost:6379",
)
yield cache_instance
cache_instance._index.delete(True) # Clean up index
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session):
default_context = semantic_session.get_relevant("list of fruits and vegetables")
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
assert default_context == semantic_session.get_relevant(
"list of fruits and vegetables",
distance_threshold=0.5
"list of fruits and vegetables", distance_threshold=0.5
)

# test tool calls can also be returned
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_query_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from redis.commands.search.query import Query
from redis.commands.search.result import Result

Expand Down Expand Up @@ -124,10 +123,11 @@ def test_vector_query():
["field1", "field2"],
dialect=3,
num_results=10,
in_order=True
in_order=True,
)
assert vector_query._in_order


def test_range_query():
# Create a filter expression
filter_expression = Tag("brand") == "Nike"
Expand Down Expand Up @@ -182,7 +182,7 @@ def test_range_query():
["field1"],
filter_expression,
num_results=10,
in_order=True
in_order=True,
)
assert range_query._in_order

Expand All @@ -193,7 +193,7 @@ def test_range_query():
CountQuery(),
FilterQuery(),
VectorQuery(vector=[1, 2, 3], vector_field_name="vector"),
RangeQuery(vector=[1, 2, 3], vector_field_name="vector")
RangeQuery(vector=[1, 2, 3], vector_field_name="vector"),
],
)
def test_query_modifiers(query):
Expand Down

0 comments on commit 156a14b

Please sign in to comment.