From cd37576defa21f63676a02b9cae1bec5fe5dd539 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 23 Aug 2024 11:37:01 -0400 Subject: [PATCH 01/10] wip --- redisvl/query/query.py | 167 +++++++++++++++++------------------------ 1 file changed, 69 insertions(+), 98 deletions(-) diff --git a/redisvl/query/query.py b/redisvl/query/query.py index a1b3832b..3b96d941 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,32 +1,19 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from redis.commands.search.query import Query +from redis.commands.search.query import Query as RedisQuery from redisvl.query.filter import FilterExpression from redisvl.redis.utils import array_to_buffer -class BaseQuery: - def __init__( - self, - return_fields: Optional[List[str]] = None, - num_results: int = 10, - dialect: int = 2, - sort_by: Optional[str] = None, - in_order: bool = False, - ): - """Base query class used to subclass many query types.""" - self._return_fields = return_fields if return_fields is not None else [] - self._num_results = num_results - self._dialect = dialect - self._first = 0 - self._limit = num_results - self._sort_by = sort_by - self._in_order = in_order +class BaseQuery(RedisQuery): + """Base query class used to subclass many query types.""" + + _filter: Optional[FilterExpression] = None def __str__(self) -> str: - return " ".join([str(x) for x in self.query.get_args()]) + return " ".join([str(x) for x in self.get_args()]) def set_filter(self, filter_expression: Optional[FilterExpression] = None): """Set the filter expression for the query. @@ -56,25 +43,8 @@ def get_filter(self) -> FilterExpression: """ return self._filter - def set_paging(self, first: int, limit: int): - """Set the paging parameters for the query to limit the number of - results. - - Args: - first (int): The zero-indexed offset for which to fetch query results - limit (int): The max number of results to include including the offset - - Raises: - TypeError: If first or limit are NOT integers. - """ - if not isinstance(first, int) or not isinstance(limit, int): - raise TypeError("Paging params must both be integers") - - self._first = first - self._limit = limit - @property - def query(self) -> Query: + def query(self) -> "BaseQuery": raise NotImplementedError @property @@ -82,131 +52,132 @@ def params(self) -> Dict[str, Any]: return {} -class CountQuery(BaseQuery): +class FilterQuery(BaseQuery): def __init__( self, filter_expression: FilterExpression, + return_fields: Optional[List[str]] = None, + num_results: int = 10, dialect: int = 2, + # sort_by: Optional[str] = None, + # in_order: bool = False, params: Optional[Dict[str, Any]] = None, ): - """A query for a simple count operation provided some filter expression. + """A query for a running a filtered search with a filter expression. Args: - filter_expression (FilterExpression): The filter expression to query for. - params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. + filter_expression (FilterExpression): The filter expression to + query for. + return_fields (Optional[List[str]], optional): The fields to return. + num_results (Optional[int], optional): The number of results to + return. Defaults to 10. + sort_by (Optional[str]): The field to order the results by. Defaults + to None. Results will be ordered by vector distance. + in_order (bool): Requires the terms in the field to have + the same order as the terms in the query filter, regardless of + the offsets between them. Defaults to False. + params (Optional[Dict[str, Any]], optional): The parameters for the + query. Defaults to None. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression .. code-block:: python - from redisvl.query import CountQuery + + from redisvl.query import FilterQuery from redisvl.query.filter import Tag t = Tag("brand") == "Nike" - query = CountQuery(filter_expression=t) + q = FilterQuery(return_fields=["brand", "price"], filter_expression=t) - count = index.query(query) """ - super().__init__(num_results=0, dialect=dialect) - self.set_filter(filter_expression) self._params = params or {} + self._filter = filter_expression + + super().__init__(str(self._filter)) + + self.return_fields(*return_fields).paging(0, num_results).dialect(dialect) + + # if sort_by: + # self.sort_by(sort_by) + + # if in_order: + # self.in_order() + @property - def query(self) -> Query: - """The loaded Redis-Py query. + def query(self) -> "FilterQuery": + """Return a Redis-Py Query object representing the query. Returns: redis.commands.search.query.Query: The Redis-Py query object. """ - base_query = str(self._filter) - query = Query(base_query).no_content().paging(0, 0).dialect(self._dialect) - return query + return self - @property - def params(self) -> Dict[str, Any]: - """The parameters for the query. - Returns: - Dict[str, Any]: The parameters for the query. - """ - return self._params -class FilterQuery(BaseQuery): + +class CountQuery(BaseQuery): def __init__( self, filter_expression: FilterExpression, - return_fields: Optional[List[str]] = None, - num_results: int = 10, dialect: int = 2, - sort_by: Optional[str] = None, - in_order: bool = False, params: Optional[Dict[str, Any]] = None, ): - """A query for a running a filtered search with a filter expression. + """A query for a simple count operation provided some filter expression. Args: - filter_expression (FilterExpression): The filter expression to - query for. - return_fields (Optional[List[str]], optional): The fields to return. - num_results (Optional[int], optional): The number of results to - return. Defaults to 10. - sort_by (Optional[str]): The field to order the results by. Defaults - to None. Results will be ordered by vector distance. - in_order (bool): Requires the terms in the field to have - the same order as the terms in the query filter, regardless of - the offsets between them. Defaults to False. - params (Optional[Dict[str, Any]], optional): The parameters for the - query. Defaults to None. + filter_expression (FilterExpression): The filter expression to query for. + params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression .. code-block:: python - - from redisvl.query import FilterQuery + from redisvl.query import CountQuery from redisvl.query.filter import Tag t = Tag("brand") == "Nike" - q = FilterQuery(return_fields=["brand", "price"], filter_expression=t) + query = CountQuery(filter_expression=t) + count = index.query(query) """ - super().__init__(return_fields, num_results, dialect, sort_by, in_order) - self.set_filter(filter_expression) self._params = params or {} + self._filter = filter_expression + + super().__init__(str(self._filter)) + self.no_content().paging(0, 0).dialect(dialect) @property - def query(self) -> Query: - """Return a Redis-Py Query object representing the query. + def query(self) -> "BaseQuery": + """The loaded Redis-Py query. Returns: redis.commands.search.query.Query: The Redis-Py query object. """ - base_query = str(self._filter) - query = ( - Query(base_query) - .return_fields(*self._return_fields) - .paging(self._first, self._limit) - .dialect(self._dialect) - ) - if self._sort_by: - query = query.sort_by(self._sort_by) + return self - if self._in_order: - query = query.in_order() + @property + def params(self) -> Dict[str, Any]: + """The parameters for the query. + + Returns: + Dict[str, Any]: The parameters for the query. + """ + return self._params - return query class BaseVectorQuery(BaseQuery): - DTYPES = { + DTYPES: Dict[str, np.dtype] = { "float32": np.float32, "float64": np.float64, } - DISTANCE_ID = "vector_distance" - VECTOR_PARAM = "vector" + DISTANCE_ID: str = "vector_distance" + VECTOR_PARAM: str = "vector" def __init__( self, From 6ad5d16966d4a6ecf22b90305dc4fe3f9277b574 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 23 Aug 2024 13:12:55 -0400 Subject: [PATCH 02/10] wip --- redisvl/query/query.py | 320 +++++++++++++++++------------------------ 1 file changed, 134 insertions(+), 186 deletions(-) diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 3b96d941..2a2fd02e 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -10,110 +10,110 @@ class BaseQuery(RedisQuery): """Base query class used to subclass many query types.""" - _filter: Optional[FilterExpression] = None - - def __str__(self) -> str: - return " ".join([str(x) for x in self.get_args()]) - - def set_filter(self, filter_expression: Optional[FilterExpression] = None): - """Set the filter expression for the query. + def __init__(self, query_string: str = "*", params: Optional[Dict[str, Any]] = None): + """ + Initialize the BaseQuery class. Args: - filter_expression (Optional[FilterExpression], optional): The filter - to apply to the query. - - Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression + query_string (str, optional): The query string to use. Defaults to '*'. + params (Optional[Dict[str, Any]], optional): Optional parameters for the query. """ - if filter_expression is None: - # Default filter to match everything - self._filter = FilterExpression("*") - elif isinstance(filter_expression, FilterExpression): - self._filter = filter_expression - else: - raise TypeError( - "filter_expression must be of type FilterExpression or None" - ) + super().__init__(query_string) + self._params: Dict[str, Any] = params if params else {} - def get_filter(self) -> FilterExpression: - """Get the filter expression for the query. - - Returns: - FilterExpression: The filter for the query. - """ - return self._filter + def __str__(self) -> str: + """Return the string representation of the query.""" + return " ".join([str(x) for x in self.get_args()]) @property def query(self) -> "BaseQuery": - raise NotImplementedError + """Return self as the query object.""" + return self @property def params(self) -> Dict[str, Any]: - return {} + """Return the query parameters.""" + return self._params class FilterQuery(BaseQuery): def __init__( self, - filter_expression: FilterExpression, + filter_expression: Optional[FilterExpression] = None, return_fields: Optional[List[str]] = None, num_results: int = 10, dialect: int = 2, - # sort_by: Optional[str] = None, - # in_order: bool = False, + sort_by: Optional[str] = None, + in_order: bool = False, params: Optional[Dict[str, Any]] = None, ): - """A query for a running a filtered search with a filter expression. + """A query for running a filtered search with a filter expression. Args: - filter_expression (FilterExpression): The filter expression to - query for. + filter_expression (Optional[FilterExpression]): The optional filter + expression to query with. Defaults to '*'. return_fields (Optional[List[str]], optional): The fields to return. - num_results (Optional[int], optional): The number of results to - return. Defaults to 10. - sort_by (Optional[str]): The field to order the results by. Defaults - to None. Results will be ordered by vector distance. - in_order (bool): Requires the terms in the field to have - the same order as the terms in the query filter, regardless of - the offsets between them. Defaults to False. - params (Optional[Dict[str, Any]], optional): The parameters for the - query. Defaults to None. + num_results (Optional[int], optional): The number of results to return. Defaults to 10. + dialect (int, optional): The query dialect. Defaults to 2. + sort_by (Optional[str], optional): The field to order the results by. Defaults to None. + in_order (bool, optional): Requires the terms in the field to have the same order as the terms in the query filter. Defaults to False. + params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + self._filter: FilterExpression = filter_expression if filter_expression else FilterExpression("*") + self._params: Dict[str, Any] = params if params else {} - .. code-block:: python + # Initialize the base query with the full query string constructed from the filter expression + query_string = self.build_query_string() + super().__init__(query_string, params) + # Handle query settings + if return_fields: + self.return_fields(*return_fields) + self.paging(0, num_results).dialect(dialect) - from redisvl.query import FilterQuery - from redisvl.query.filter import Tag + if sort_by: + self.sort_by(sort_by) - t = Tag("brand") == "Nike" - q = FilterQuery(return_fields=["brand", "price"], filter_expression=t) + if in_order: + self.in_order() - """ - self._params = params or {} - self._filter = filter_expression - - super().__init__(str(self._filter)) + def build_query_string(self) -> str: + """Build the full query string based on the filter and other components.""" + # Example logic to build the full query string from filter and other parts + # This can be customized in child classes for more complex queries + return str(self._filter) - self.return_fields(*return_fields).paging(0, num_results).dialect(dialect) + def set_filter(self, filter_expression: Optional[FilterExpression] = None): + """Set the filter expression for the query. - # if sort_by: - # self.sort_by(sort_by) + Args: + filter_expression (Optional[FilterExpression], optional): The filter to apply to the query. - # if in_order: - # self.in_order() + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + if filter_expression is None: + # Default filter to match everything + self._filter = FilterExpression("*") + elif isinstance(filter_expression, FilterExpression): + self._filter = filter_expression + else: + raise TypeError("filter_expression must be of type FilterExpression or None") + # Rebuild the query string and reinitialize the base query + query_string = self.build_query_string() + super().__init__(query_string, self._params) - @property - def query(self) -> "FilterQuery": - """Return a Redis-Py Query object representing the query. + def get_filter(self) -> FilterExpression: + """Get the filter expression for the query. Returns: - redis.commands.search.query.Query: The Redis-Py query object. + FilterExpression: The filter for the query. """ - return self + return self._filter @@ -151,23 +151,6 @@ def __init__( super().__init__(str(self._filter)) self.no_content().paging(0, 0).dialect(dialect) - @property - def query(self) -> "BaseQuery": - """The loaded Redis-Py query. - - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - return self - - @property - def params(self) -> Dict[str, Any]: - """The parameters for the query. - - Returns: - Dict[str, Any]: The parameters for the query. - """ - return self._params @@ -179,28 +162,6 @@ class BaseVectorQuery(BaseQuery): DISTANCE_ID: str = "vector_distance" VECTOR_PARAM: str = "vector" - def __init__( - self, - vector: Union[List[float], bytes], - vector_field_name: str, - return_fields: Optional[List[str]] = None, - filter_expression: Optional[FilterExpression] = None, - dtype: str = "float32", - num_results: int = 10, - return_score: bool = True, - dialect: int = 2, - sort_by: Optional[str] = None, - in_order: bool = False, - ): - super().__init__(return_fields, num_results, dialect, sort_by, in_order) - self.set_filter(filter_expression) - self._vector = vector - self._field = vector_field_name - self._dtype = dtype.lower() - - if return_score: - self._return_fields.append(self.DISTANCE_ID) - class VectorQuery(BaseVectorQuery): def __init__( @@ -247,42 +208,34 @@ def __init__( Note: Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search """ - super().__init__( - vector, - vector_field_name, - return_fields, - filter_expression, - dtype, - num_results, - return_score, - dialect, - sort_by, - in_order, - ) + self._filter: FilterExpression = filter_expression if filter_expression else FilterExpression("*") + self._vector = vector + self._vector_field_name = vector_field_name + self._dtype = dtype + self._return_score = return_score + self._num_results = num_results - @property - def query(self) -> Query: - """Return a Redis-Py Query object representing the query. + super().__init__(self.build_query_string()) - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - base_query = f"{str(self._filter)}=>[KNN {self._num_results} @{self._field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" - query = ( - Query(base_query) - .return_fields(*self._return_fields) - .paging(self._first, self._limit) - .dialect(self._dialect) - ) - if self._sort_by: - query = query.sort_by(self._sort_by) + # Handle query modifiers + if return_fields: + self.return_fields(*return_fields) + + self.paging(0, self._num_results) + self.dialect(dialect) + + if sort_by: + self.sort_by(sort_by) else: - query = query.sort_by(self.DISTANCE_ID) + self.sort_by(self.DISTANCE_ID) + + if in_order: + self.in_order() - if self._in_order: - query = query.in_order() + def build_query_string(self) -> str: + """Build the full query string for vector search with optional filtering.""" + return f"{str(self._filter)}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" - return query @property def params(self) -> Dict[str, Any]: @@ -292,15 +245,15 @@ def params(self) -> Dict[str, Any]: Dict[str, Any]: The parameters for the query. """ if isinstance(self._vector, bytes): - vector_param = self._vector + vector = self._vector else: - vector_param = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) + vector = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) - return {self.VECTOR_PARAM: vector_param} + return {self.VECTOR_PARAM: vector} -class RangeQuery(BaseVectorQuery): - DISTANCE_THRESHOLD_PARAM = "distance_threshold" +class VectorRangeQuery(BaseVectorQuery): + DISTANCE_THRESHOLD_PARAM: str = "distance_threshold" def __init__( self, @@ -351,20 +304,44 @@ def __init__( Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query """ - super().__init__( - vector, - vector_field_name, - return_fields, - filter_expression, - dtype, - num_results, - return_score, - dialect, - sort_by, - in_order, - ) + self._filter: FilterExpression = filter_expression if filter_expression else FilterExpression("*") + self._vector = vector + self._vector_field_name = vector_field_name + self._dtype = dtype + self._return_score = return_score + self._num_results = num_results self.set_distance_threshold(distance_threshold) + super().__init__(self.build_query_string()) + + # Handle query modifiers + if return_fields: + self.return_fields(*return_fields) + + self.paging(0, self._num_results) + self.dialect(dialect) + + if sort_by: + self.sort_by(sort_by) + else: + self.sort_by(self.DISTANCE_ID) + + if in_order: + self.in_order() + + + def build_query_string(self) -> str: + """Build the full query string for vector range queries with optional filtering""" + base_query = f"@{self._field}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" + _filter = str(self._filter) + if _filter != "*": + return ( + f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})" + ) + else: + return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" + + def set_distance_threshold(self, distance_threshold: float): """Set the distance treshold for the query. @@ -384,40 +361,6 @@ def distance_threshold(self) -> float: """ return self._distance_threshold - @property - def query(self) -> Query: - """Return a Redis-Py Query object representing the query. - - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - base_query = f"@{self._field}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" - - _filter = str(self._filter) - - if _filter != "*": - base_query = ( - f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})" - ) - else: - base_query = f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" - - query = ( - Query(base_query) - .return_fields(*self._return_fields) - .paging(self._first, self._limit) - .dialect(self._dialect) - ) - if self._sort_by: - query = query.sort_by(self._sort_by) - else: - query = query.sort_by(self.DISTANCE_ID) - - if self._in_order: - query = query.in_order() - - return query - @property def params(self) -> Dict[str, Any]: """Return the parameters for the query. @@ -434,3 +377,8 @@ def params(self) -> Dict[str, Any]: self.VECTOR_PARAM: vector_param, self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold, } + + +class RangeQuery(VectorRangeQuery): + # for backwards compatibility + pass From b0b18950c1c0c590c4ddde68023b85939e45d46f Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 30 Aug 2024 13:58:42 -0400 Subject: [PATCH 03/10] update query classes and tests --- redisvl/extensions/llmcache/semantic.py | 22 ++-- redisvl/index/index.py | 4 +- redisvl/query/query.py | 161 ++++++++++++++---------- tests/integration/test_llmcache.py | 5 +- tests/unit/test_query_types.py | 71 +++++++---- 5 files changed, 156 insertions(+), 107 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 17856196..e7e0b290 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -296,14 +296,9 @@ def check( # overrides distance_threshold = distance_threshold or self._distance_threshold - return_fields = return_fields or self.return_fields vector = vector or self._vectorize_prompt(prompt) - self._check_vector_dims(vector) - if not isinstance(return_fields, list): - raise TypeError("return_fields must be a list of field names") - query = RangeQuery( vector=vector, vector_field_name=self.vector_field_name, @@ -320,15 +315,18 @@ def check( cache_search_results = self._index.query(query) for cache_search_result in cache_search_results: - key = cache_search_result["id"] - self._refresh_ttl(key) + redis_key = cache_search_result.pop("id") + self._refresh_ttl(redis_key) - # Create cache hit + # Create and process cache hit cache_hit = CacheHit(**cache_search_result) - cache_hit_dict = { - k: v for k, v in cache_hit.to_dict().items() if k in return_fields - } - cache_hit_dict["key"] = key + cache_hit_dict = cache_hit.to_dict() + # Filter down to only selected return fields if needed + if isinstance(return_fields, list) and len(return_fields) > 0: + cache_hit_dict = { + k: v for k, v in cache_hit_dict.items() if k in return_fields + } + cache_hit_dict[self.redis_key_field_name] = redis_key cache_hits.append(cache_hit_dict) return cache_hits diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f5e6b4a6..4b2645a4 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -710,7 +710,7 @@ def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator: offset = 0 while True: - query.set_paging(offset, page_size) + query.paging(offset, page_size) results = self._query(query) if not results: break @@ -1194,7 +1194,7 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato first = 0 while True: - query.set_paging(first, page_size) + query.paging(first, page_size) results = await self._query(query) if not results: break diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 2a2fd02e..c360c275 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -10,21 +10,59 @@ class BaseQuery(RedisQuery): """Base query class used to subclass many query types.""" - def __init__(self, query_string: str = "*", params: Optional[Dict[str, Any]] = None): + _params: Dict[str, Any] = {} + _filter_expression: FilterExpression = FilterExpression("*") + + def __init__(self, query_string: str = "*"): """ Initialize the BaseQuery class. Args: query_string (str, optional): The query string to use. Defaults to '*'. - params (Optional[Dict[str, Any]], optional): Optional parameters for the query. """ super().__init__(query_string) - self._params: Dict[str, Any] = params if params else {} def __str__(self) -> str: """Return the string representation of the query.""" return " ".join([str(x) for x in self.get_args()]) + def _build_query_string(self) -> str: + """Build the full Redis query string.""" + pass + + def set_filter(self, filter_expression: Optional[FilterExpression] = None): + """Set the filter expression for the query. + + Args: + filter_expression (Optional[FilterExpression], optional): The filter to apply to the query. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + if filter_expression is None: + # Default filter to match everything + self._filter_expression = FilterExpression("*") + elif isinstance(filter_expression, FilterExpression): + self._filter_expression = filter_expression + else: + raise TypeError("filter_expression must be of type FilterExpression or None") + + # Reset the query string + self._query_string = self._build_query_string() + + def get_filter(self) -> FilterExpression: + """Get the filter expression for the query. + + Returns: + FilterExpression: The filter for the query. + """ + return self.filter + + @property + def filter(self) -> FilterExpression: + """The filter expression for the query.""" + return self._filter_expression + @property def query(self) -> "BaseQuery": """Return self as the query object.""" @@ -62,17 +100,21 @@ def __init__( Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression """ - self._filter: FilterExpression = filter_expression if filter_expression else FilterExpression("*") - self._params: Dict[str, Any] = params if params else {} + if filter_expression: + self._filter_expression = filter_expression + if params: + self._params = params + + self._num_results = num_results # Initialize the base query with the full query string constructed from the filter expression - query_string = self.build_query_string() - super().__init__(query_string, params) + query_string = self._build_query_string() + super().__init__(query_string) # Handle query settings if return_fields: self.return_fields(*return_fields) - self.paging(0, num_results).dialect(dialect) + self.paging(0, self._num_results).dialect(dialect) if sort_by: self.sort_by(sort_by) @@ -80,56 +122,24 @@ def __init__( if in_order: self.in_order() - def build_query_string(self) -> str: + def _build_query_string(self) -> str: """Build the full query string based on the filter and other components.""" # Example logic to build the full query string from filter and other parts # This can be customized in child classes for more complex queries - return str(self._filter) - - def set_filter(self, filter_expression: Optional[FilterExpression] = None): - """Set the filter expression for the query. - - Args: - filter_expression (Optional[FilterExpression], optional): The filter to apply to the query. - - Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression - """ - if filter_expression is None: - # Default filter to match everything - self._filter = FilterExpression("*") - elif isinstance(filter_expression, FilterExpression): - self._filter = filter_expression - else: - raise TypeError("filter_expression must be of type FilterExpression or None") - - # Rebuild the query string and reinitialize the base query - query_string = self.build_query_string() - super().__init__(query_string, self._params) - - def get_filter(self) -> FilterExpression: - """Get the filter expression for the query. - - Returns: - FilterExpression: The filter for the query. - """ - return self._filter - - - + return str(self._filter_expression) class CountQuery(BaseQuery): def __init__( self, - filter_expression: FilterExpression, + filter_expression: Optional[FilterExpression] = None, dialect: int = 2, params: Optional[Dict[str, Any]] = None, ): """A query for a simple count operation provided some filter expression. Args: - filter_expression (FilterExpression): The filter expression to query for. + filter_expression (Optional[FilterExpression]): The filter expression to query with. Defaults to None. params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. Raises: @@ -145,13 +155,23 @@ def __init__( count = index.query(query) """ - self._params = params or {} - self._filter = filter_expression + if filter_expression: + self._filter_expression = filter_expression + if params: + self._params = params - super().__init__(str(self._filter)) - self.no_content().paging(0, 0).dialect(dialect) + # Initialize the base query with the full query string constructed from the filter expression + query_string = self._build_query_string() + super().__init__(query_string) + # Query specific modifications + self.no_content().paging(0, 0).dialect(dialect) + def _build_query_string(self) -> str: + """Build the full query string based on the filter and other components.""" + # Example logic to build the full query string from filter and other parts + # This can be customized in child classes for more complex queries + return str(self._filter_expression) class BaseVectorQuery(BaseQuery): @@ -208,21 +228,25 @@ def __init__( Note: Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search """ - self._filter: FilterExpression = filter_expression if filter_expression else FilterExpression("*") + if filter_expression: + self._filter_expression = filter_expression + self._vector = vector self._vector_field_name = vector_field_name self._dtype = dtype - self._return_score = return_score self._num_results = num_results - super().__init__(self.build_query_string()) + query_string = self._build_query_string() + super().__init__(query_string) # Handle query modifiers if return_fields: self.return_fields(*return_fields) - self.paging(0, self._num_results) - self.dialect(dialect) + self.paging(0, self._num_results).dialect(dialect) + + if return_score: + self.return_fields(self.DISTANCE_ID) if sort_by: self.sort_by(sort_by) @@ -232,10 +256,9 @@ def __init__( if in_order: self.in_order() - def build_query_string(self) -> str: + def _build_query_string(self) -> str: """Build the full query string for vector search with optional filtering.""" - return f"{str(self._filter)}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" - + return f"{str(self._filter_expression)}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" @property def params(self) -> Dict[str, Any]: @@ -304,22 +327,26 @@ def __init__( Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query """ - self._filter: FilterExpression = filter_expression if filter_expression else FilterExpression("*") + if filter_expression: + self._filter_expression = filter_expression + self._vector = vector self._vector_field_name = vector_field_name self._dtype = dtype - self._return_score = return_score self._num_results = num_results self.set_distance_threshold(distance_threshold) - super().__init__(self.build_query_string()) + query_string = self._build_query_string() + super().__init__(query_string) # Handle query modifiers if return_fields: self.return_fields(*return_fields) - self.paging(0, self._num_results) - self.dialect(dialect) + self.paging(0, self._num_results).dialect(dialect) + + if return_score: + self.return_fields(self.DISTANCE_ID) if sort_by: self.sort_by(sort_by) @@ -329,11 +356,10 @@ def __init__( if in_order: self.in_order() - - def build_query_string(self) -> str: + def _build_query_string(self) -> str: """Build the full query string for vector range queries with optional filtering""" - base_query = f"@{self._field}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" - _filter = str(self._filter) + base_query = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" + _filter = str(self._filter_expression) if _filter != "*": return ( f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})" @@ -341,9 +367,8 @@ def build_query_string(self) -> str: else: return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" - def set_distance_threshold(self, distance_threshold: float): - """Set the distance treshold for the query. + """Set the distance threshold for the query. Args: distance_threshold (float): vector distance @@ -380,5 +405,5 @@ def params(self) -> Dict[str, Any]: class RangeQuery(VectorRangeQuery): - # for backwards compatibility + # keep for backwards compatibility pass diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 2263b745..9252f2bd 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -19,7 +19,7 @@ def vectorizer(): @pytest.fixture def cache(vectorizer, redis_url): cache_instance = SemanticCache( - vectorizer=vectorizer, distance_threshold=0.2, redis_url=redis_url + vectorizer=vectorizer, distance_threshold=0.2, redis_url="redis://localhost:6379" ) yield cache_instance cache_instance._index.delete(True) # Clean up index @@ -266,9 +266,6 @@ def test_check_invalid_input(cache): with pytest.raises(ValueError): cache.check() - with pytest.raises(TypeError): - cache.check(prompt="test", return_fields="bad value") - def test_bad_connection_info(vectorizer): with pytest.raises(ConnectionError): diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index e0fd4f17..c5bbe881 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -18,6 +18,7 @@ def test_count_query(): count_query = CountQuery(filter_expression) # Check properties + assert isinstance(count_query) assert isinstance(count_query.query, Query) assert isinstance(count_query.params, dict) assert count_query.params == {} @@ -41,11 +42,12 @@ def test_filter_query(): assert filter_query._return_fields == return_fields assert filter_query._num_results == 10 assert filter_query.get_filter() == filter_expression + assert isinstance(filter_query, Query) assert isinstance(filter_query.query, Query) assert isinstance(filter_query.params, dict) assert filter_query.params == {} assert filter_query._dialect == 2 - assert filter_query._sort_by == None + assert filter_query._sortby == None assert filter_query._in_order == False # Test set_filter functionality @@ -53,17 +55,17 @@ def test_filter_query(): filter_query.set_filter(new_filter_expression) assert filter_query.get_filter() == new_filter_expression - # Test set_paging functionality - filter_query.set_paging(5, 7) - assert filter_query._first == 5 - assert filter_query._limit == 7 + # Test paging functionality + filter_query.paging(5, 7) + assert filter_query._offset == 5 + assert filter_query._num == 7 assert filter_query._num_results == 10 # Test sort_by functionality filter_query = FilterQuery( filter_expression, return_fields, num_results=10, sort_by="price" ) - assert filter_query._sort_by == "price" + assert filter_query._sortby is not None # Test in_order functionality filter_query = FilterQuery( @@ -80,14 +82,15 @@ def test_vector_query(): # Check properties assert vector_query._vector == sample_vector - assert vector_query._field == "vector_field" + assert vector_query._vector_field_name == "vector_field" assert vector_query._num_results == 10 assert "field1" in vector_query._return_fields + assert isinstance(vector_query, Query) assert isinstance(vector_query.query, Query) assert isinstance(vector_query.params, dict) assert vector_query.params != {} assert vector_query._dialect == 3 - assert vector_query._sort_by == None + assert vector_query._sortby.args[0] == VectorQuery.DISTANCE_ID assert vector_query._in_order == False # Test set_filter functionality @@ -95,10 +98,10 @@ def test_vector_query(): vector_query.set_filter(new_filter_expression) assert vector_query.get_filter() == new_filter_expression - # Test set_paging functionality - vector_query.set_paging(5, 7) - assert vector_query._first == 5 - assert vector_query._limit == 7 + # Test paging functionality + vector_query.paging(5, 7) + assert vector_query._offset == 5 + assert vector_query._num == 7 assert vector_query._num_results == 10 # Test sort_by functionality @@ -110,8 +113,18 @@ def test_vector_query(): num_results=10, sort_by="field2", ) - assert vector_query._sort_by == "field2" + assert vector_query._sortby.args[0] == "field2" + # Test in_order functionality + vector_query = VectorQuery( + sample_vector, + "vector_field", + ["field1", "field2"], + dialect=3, + num_results=10, + in_order=True + ) + assert vector_query._in_order def test_range_query(): # Create a filter expression @@ -124,25 +137,29 @@ def test_range_query(): # Check properties assert range_query._vector == sample_vector - assert range_query._field == "vector_field" + assert range_query._vector_field_name == "vector_field" assert range_query._num_results == 10 assert range_query.distance_threshold == 0.2 assert "field1" in range_query._return_fields + assert isinstance(range_query, Query) assert isinstance(range_query.query, Query) assert isinstance(range_query.params, dict) assert range_query.params != {} - assert range_query._sort_by == None - assert range_query._sort_by == None + assert range_query._sortby.args[0] == RangeQuery.DISTANCE_ID + + # Test set_distance_threshold functionality + range_query.set_distance_threshold(0.1) + assert range_query.distance_threshold == 0.1 # Test set_filter functionality new_filter_expression = Tag("category") == "Outdoor" range_query.set_filter(new_filter_expression) assert range_query.get_filter() == new_filter_expression - # Test set_paging functionality - range_query.set_paging(5, 7) - assert range_query._first == 5 - assert range_query._limit == 7 + # Test paging functionality + range_query.paging(5, 7) + assert range_query._offset == 5 + assert range_query._num == 7 assert range_query._num_results == 10 # Test sort_by functionality @@ -154,4 +171,16 @@ def test_range_query(): num_results=10, sort_by="field1", ) - assert range_query._sort_by == "field1" + assert range_query._sortby.args[0] == "field1" + + # Test in_order functionality + range_query = RangeQuery( + sample_vector, + "vector_field", + ["field1"], + filter_expression, + num_results=10, + in_order=True + ) + assert range_query._in_order + From d999853cfc88df96089231c1fe512df9dbe62fe7 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 30 Aug 2024 16:12:31 -0400 Subject: [PATCH 04/10] fix test --- tests/unit/test_query_types.py | 72 +++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index c5bbe881..05363ca7 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -1,3 +1,5 @@ +import pytest + from redis.commands.search.query import Query from redis.commands.search.result import Result @@ -18,7 +20,7 @@ def test_count_query(): count_query = CountQuery(filter_expression) # Check properties - assert isinstance(count_query) + assert isinstance(count_query, Query) assert isinstance(count_query.query, Query) assert isinstance(count_query.params, dict) assert count_query.params == {} @@ -84,7 +86,7 @@ def test_vector_query(): assert vector_query._vector == sample_vector assert vector_query._vector_field_name == "vector_field" assert vector_query._num_results == 10 - assert "field1" in vector_query._return_fields + assert vector_query._return_fields == ["field1", "field2", "vector_distance"] assert isinstance(vector_query, Query) assert isinstance(vector_query.query, Query) assert isinstance(vector_query.params, dict) @@ -184,3 +186,69 @@ def test_range_query(): ) assert range_query._in_order + +@pytest.mark.parametrize( + "query", + [ + CountQuery(), + FilterQuery(), + VectorQuery(vector=[1, 2, 3], vector_field_name="vector"), + RangeQuery(vector=[1, 2, 3], vector_field_name="vector") + ], +) +def test_query_modifiers(query): + query.paging(3, 5) + assert query._offset == 3 + assert query._num == 5 + + query.dialect(4) + assert query._dialect == 4 + + query.in_order() + assert query._in_order + + query.sort_by("time") + assert query._sortby.args[0] == "time" + + query.scorer("BM25") + assert query._scorer == "BM25" + + query.timeout(20) + assert query._timeout == 20 + + query.slop(10) + assert query._slop == 10 + + query.verbatim() + assert query._verbatim + + query.no_content() + assert query._no_content + + query.no_stopwords() + assert query._no_stopwords + + query.with_scores() + assert query._with_scores + + query.limit_fields("test") + assert query._fields == ("test",) + + f = Tag("test") == "foo" + query.set_filter(f) + assert query._filter_expression == f + + # double check all other states + assert query._offset == 3 + assert query._num == 5 + assert query._dialect == 4 + assert query._in_order + assert query._sortby.args[0] == "time" + assert query._scorer == "BM25" + assert query._timeout == 20 + assert query._slop == 10 + assert query._verbatim + assert query._no_content + assert query._no_stopwords + assert query._with_scores + assert query._fields == ("test",) From 02fca2de8ef01252b4a0a96f700fb70a6bda8217 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 30 Aug 2024 16:28:14 -0400 Subject: [PATCH 05/10] add vectorrangequery and update sphinx docs --- docs/api/query.rst | 19 ++++++++++++++----- redisvl/query/__init__.py | 10 +++++++++- redisvl/query/query.py | 7 ------- tests/unit/test_query_types.py | 10 +++++----- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/docs/api/query.rst b/docs/api/query.rst index a06d8bad..8086f5ca 100644 --- a/docs/api/query.rst +++ b/docs/api/query.rst @@ -3,7 +3,9 @@ Query ***** -.. _query_api: +Query classes in RedisVL provide a structured way to define simple or complex +queries for different use cases. Each query class wraps the ``redis-py`` Query module +https://github.com/redis/redis-py/blob/master/redis/commands/search/query.py with extended functionality for ease-of-use. VectorQuery @@ -15,19 +17,22 @@ VectorQuery .. autoclass:: VectorQuery :members: :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize -RangeQuery -========== +VectorRangeQuery +================ .. currentmodule:: redisvl.query -.. autoclass:: RangeQuery +.. autoclass:: VectorRangeQuery :members: :inherited-members: - + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize FilterQuery =========== @@ -39,6 +44,8 @@ FilterQuery .. autoclass:: FilterQuery :members: :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize @@ -51,3 +58,5 @@ CountQuery .. autoclass:: CountQuery :members: :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index 68182e0f..ecae6bad 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -4,6 +4,14 @@ FilterQuery, RangeQuery, VectorQuery, + VectorRangeQuery, ) -__all__ = ["BaseQuery", "VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"] +__all__ = [ + "BaseQuery", + "VectorQuery", + "FilterQuery", + "RangeQuery", + "VectorRangeQuery", + "CountQuery" +] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index c360c275..e111046a 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -50,13 +50,6 @@ def set_filter(self, filter_expression: Optional[FilterExpression] = None): # Reset the query string self._query_string = self._build_query_string() - def get_filter(self) -> FilterExpression: - """Get the filter expression for the query. - - Returns: - FilterExpression: The filter for the query. - """ - return self.filter @property def filter(self) -> FilterExpression: diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 05363ca7..00081ea5 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -28,7 +28,7 @@ def test_count_query(): # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" count_query.set_filter(new_filter_expression) - assert count_query.get_filter() == new_filter_expression + assert count_query.filter == new_filter_expression fake_result = Result([2], "") assert process_results(fake_result, count_query, "json") == 2 @@ -43,7 +43,7 @@ def test_filter_query(): # Check properties assert filter_query._return_fields == return_fields assert filter_query._num_results == 10 - assert filter_query.get_filter() == filter_expression + assert filter_query.filter == filter_expression assert isinstance(filter_query, Query) assert isinstance(filter_query.query, Query) assert isinstance(filter_query.params, dict) @@ -55,7 +55,7 @@ def test_filter_query(): # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" filter_query.set_filter(new_filter_expression) - assert filter_query.get_filter() == new_filter_expression + assert filter_query.filter == new_filter_expression # Test paging functionality filter_query.paging(5, 7) @@ -98,7 +98,7 @@ def test_vector_query(): # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" vector_query.set_filter(new_filter_expression) - assert vector_query.get_filter() == new_filter_expression + assert vector_query.filter == new_filter_expression # Test paging functionality vector_query.paging(5, 7) @@ -156,7 +156,7 @@ def test_range_query(): # Test set_filter functionality new_filter_expression = Tag("category") == "Outdoor" range_query.set_filter(new_filter_expression) - assert range_query.get_filter() == new_filter_expression + assert range_query.filter == new_filter_expression # Test paging functionality range_query.paging(5, 7) From 156a14bf3f8488834eb3f61fabe1e8e949d59474 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 30 Aug 2024 16:47:54 -0400 Subject: [PATCH 06/10] formatting and linting --- redisvl/index/index.py | 2 +- redisvl/query/__init__.py | 2 +- redisvl/query/query.py | 15 ++++++++------- tests/integration/test_llmcache.py | 4 +++- tests/integration/test_session_manager.py | 3 +-- tests/unit/test_query_types.py | 8 ++++---- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 4b2645a4..d5aeb5e2 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -65,7 +65,7 @@ def process_results( unpack_json = ( (storage_type == StorageType.JSON) and isinstance(query, FilterQuery) - and not query._return_fields + and not query._return_fields # type: ignore ) # Process records diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index ecae6bad..8246794f 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -13,5 +13,5 @@ "FilterQuery", "RangeQuery", "VectorRangeQuery", - "CountQuery" + "CountQuery", ] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index e111046a..9ba05481 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -28,7 +28,7 @@ def __str__(self) -> str: def _build_query_string(self) -> str: """Build the full Redis query string.""" - pass + raise NotImplementedError("Must be implemented by subclasses") def set_filter(self, filter_expression: Optional[FilterExpression] = None): """Set the filter expression for the query. @@ -45,12 +45,13 @@ def set_filter(self, filter_expression: Optional[FilterExpression] = None): elif isinstance(filter_expression, FilterExpression): self._filter_expression = filter_expression else: - raise TypeError("filter_expression must be of type FilterExpression or None") + raise TypeError( + "filter_expression must be of type FilterExpression or None" + ) # Reset the query string self._query_string = self._build_query_string() - @property def filter(self) -> FilterExpression: """The filter expression for the query.""" @@ -167,8 +168,8 @@ def _build_query_string(self) -> str: return str(self._filter_expression) -class BaseVectorQuery(BaseQuery): - DTYPES: Dict[str, np.dtype] = { +class BaseVectorQuery: + DTYPES: Dict[str, Any] = { "float32": np.float32, "float64": np.float64, } @@ -176,7 +177,7 @@ class BaseVectorQuery(BaseQuery): VECTOR_PARAM: str = "vector" -class VectorQuery(BaseVectorQuery): +class VectorQuery(BaseVectorQuery, BaseQuery): def __init__( self, vector: Union[List[float], bytes], @@ -268,7 +269,7 @@ def params(self) -> Dict[str, Any]: return {self.VECTOR_PARAM: vector} -class VectorRangeQuery(BaseVectorQuery): +class VectorRangeQuery(BaseVectorQuery, BaseQuery): DISTANCE_THRESHOLD_PARAM: str = "distance_threshold" def __init__( diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 9252f2bd..f1c6e262 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -19,7 +19,9 @@ def vectorizer(): @pytest.fixture def cache(vectorizer, redis_url): cache_instance = SemanticCache( - vectorizer=vectorizer, distance_threshold=0.2, redis_url="redis://localhost:6379" + vectorizer=vectorizer, + distance_threshold=0.2, + redis_url="redis://localhost:6379", ) yield cache_instance cache_instance._index.delete(True) # Clean up index diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 00081ea5..17426868 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.search.query import Query from redis.commands.search.result import Result @@ -124,10 +123,11 @@ def test_vector_query(): ["field1", "field2"], dialect=3, num_results=10, - in_order=True + in_order=True, ) assert vector_query._in_order + def test_range_query(): # Create a filter expression filter_expression = Tag("brand") == "Nike" @@ -182,7 +182,7 @@ def test_range_query(): ["field1"], filter_expression, num_results=10, - in_order=True + in_order=True, ) assert range_query._in_order @@ -193,7 +193,7 @@ def test_range_query(): CountQuery(), FilterQuery(), VectorQuery(vector=[1, 2, 3], vector_field_name="vector"), - RangeQuery(vector=[1, 2, 3], vector_field_name="vector") + RangeQuery(vector=[1, 2, 3], vector_field_name="vector"), ], ) def test_query_modifiers(query): From a2f918febae9d4259f06af2ac8b48f18fc99a2e3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 6 Sep 2024 14:17:30 -0400 Subject: [PATCH 07/10] update query notebook examples --- docs/user_guide/hybrid_queries_02.ipynb | 195 +++++++++++++----------- 1 file changed, 107 insertions(+), 88 deletions(-) diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index e47afc1f..d09be7db 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -94,15 +94,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m08:48:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m08:48:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" ] } ], @@ -113,10 +113,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "# load data to redis\n", "keys = index.load(data)" ] }, @@ -141,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -175,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -201,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -227,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -264,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -299,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -326,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -352,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -387,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -415,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -441,7 +442,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -467,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -493,7 +494,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -519,7 +520,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -554,7 +555,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -582,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -608,7 +609,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -645,7 +646,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -688,7 +689,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -732,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -747,7 +748,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -772,7 +773,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -797,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -822,7 +823,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -856,7 +857,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -898,7 +899,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -932,7 +933,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -973,7 +974,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -1004,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -1032,24 +1033,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Other Redis Queries\n", - "\n", - "Sometimes there may be a case where RedisVL does not cover the explicit functionality required by the query either because of new releases that haven't been implemented in the client, or because of a very specific use case. In these cases, it is possible to use the ``SearchIndex.search`` method to execute query with a redis-py ``Query`` object or through a raw redis string.\n", - "\n", - "For example\n", - "\n", - "### Redis-Py" + "## Advanced Query Modifiers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See all modifier options available on the query API docs: https://www.redisvl.com/api/query.html" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
0.26666665077294nancyhighdoctor-122.4194,37.7749
0.65330135822335joemediumdentist-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
0.21788203716315taimurlowCEO-122.0839,37.3861
014derricklowdoctor-122.4194,37.7749
0.15880894660912timhighdermatologist-122.0839,37.3861
" + "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
" ], "text/plain": [ "" @@ -1060,14 +1062,16 @@ } ], "source": [ - "# Manipulate the Redis-py Query object\n", - "redis_py_query = v.query\n", - "\n", - "# choose to sort by age instead of vector distance\n", - "redis_py_query.sort_by(\"age\", asc=False)\n", + "# Sort by a different field and change dialect\n", + "v = VectorQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " num_results=5,\n", + " filter_expression=is_engineer\n", + ").sort_by(\"age\", asc=False).dialect(3)\n", "\n", - "# run the query with the ``SearchIndex.search`` method\n", - "result = index.search(redis_py_query, v.params)\n", + "result = index.query(v)\n", "result_print(result)" ] }, @@ -1077,76 +1081,64 @@ "source": [ "### Raw Redis Query String\n", "\n", - "So one case might be where you simply want to have a search that only filters on a tag field and don't need other functionality. Conversely, you may need to have a query that is more complex than what is currently supported by RedisVL. In these cases, you can use the ``SearchIndex.search`` again with just a raw redis query string." + "Sometimes it's helpful to convert these classes into their raw Redis query strings." ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'@credit_score:{high}'" + "'@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5'" ] }, - "execution_count": 36, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "t = Tag(\"credit_score\") == \"high\"\n", - "\n", - "str(t)" + "# check out the complex query from above\n", + "str(v)" ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'id': 'user_queries_docs:0e511391dcf346639669bdba70a189c0', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:d204e8e5df90467dbff5b2fb6f800a78', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:7cf3d6b1a4044966b4f0c5d3725a5e03', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:f6581edaaeaf432a85c1d1df8fdf5edc', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" - ] + "data": { + "text/plain": [ + "'@credit_score:{high}'" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "results = index.search(str(t))\n", - "for r in results.docs:\n", - " print(r.__dict__)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inspecting Queries\n", - "\n", - "In this example, we will show how to inspect the query that is generated by RedisVL. This can be useful for debugging purposes or for understanding how the query is being executed.\n", + "t = Tag(\"credit_score\") == \"high\"\n", "\n", - "Let's again take the example of a query that combines a numeric filter with a tag filter. This will search for users that are between the ages of between 18 and 100, have a high credit score, and sort by closest vector distance to the query vector." + "str(t)" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10'" + "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])'" ] }, - "execution_count": 38, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -1158,15 +1150,42 @@ "\n", "combined = t & low & high\n", "\n", - "v.set_filter(combined)\n", - "\n", - "# Using the str() method, you can see what Redis Query this will emit.\n", - "str(v)" + "str(combined)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The RedisVL `SearchIndex` class exposes a `search()` method which is a simple wrapper around the `FT.SEARCH` API.\n", + "Provide any valid Redis query string." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': 'user_queries_docs:43dc726b8a9541a6ab40ddedc8e48657', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:93fdc65248a64fd390ed77aa3c248c23', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:f1d1f69e5e6c41cb9b7ae70ed8f75da5', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:5dc68e47ef6d4a0f885c67368f0710b7', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + ] + } + ], + "source": [ + "results = index.search(str(t))\n", + "for r in results.docs:\n", + " print(r.__dict__)" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -1191,7 +1210,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.11.9" }, "orig_nbformat": 4, "vscode": { From 475d424d27f4abf73769f85019dcc30ecb9dd9a1 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 6 Sep 2024 14:28:11 -0400 Subject: [PATCH 08/10] port the sortby support to session manager internals --- .../session_manager/semantic_session.py | 16 ++++------------ .../session_manager/standard_session.py | 16 ++++------------ 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 773f3fc5..f5f4c37b 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -120,12 +120,8 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: filter_expression=self._default_session_filter, return_fields=return_fields, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=True) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=True) + messages = self._index.query(query) return self._format_context(messages, as_text=False) @@ -255,12 +251,8 @@ def get_recent( return_fields=return_fields, num_results=top_k, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=False) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=False) + messages = self._index.query(query) if raw: return messages[::-1] diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 37d18a3e..9ecfbb5d 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -103,12 +103,8 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: filter_expression=self._default_session_filter, return_fields=return_fields, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=True) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=True) + messages = self._index.query(query) return self._format_context(messages, as_text=False) @@ -160,12 +156,8 @@ def get_recent( return_fields=return_fields, num_results=top_k, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=False) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=False) + messages = self._index.query(query) if raw: return messages[::-1] From 42192ae1b61e396473a54b950266761b92a7c4fa Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 6 Sep 2024 14:29:51 -0400 Subject: [PATCH 09/10] small formatting updates --- redisvl/index/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index d5aeb5e2..5a4fd2df 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -65,7 +65,7 @@ def process_results( unpack_json = ( (storage_type == StorageType.JSON) and isinstance(query, FilterQuery) - and not query._return_fields # type: ignore + and not query._return_fields # type: ignore ) # Process records From 4e87c68f7cd2ea96b28866ffbf941b08204f1176 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 6 Sep 2024 14:32:53 -0400 Subject: [PATCH 10/10] update redis url for test --- tests/integration/test_llmcache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index f1c6e262..09c9327b 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -21,7 +21,7 @@ def cache(vectorizer, redis_url): cache_instance = SemanticCache( vectorizer=vectorizer, distance_threshold=0.2, - redis_url="redis://localhost:6379", + redis_url=redis_url, ) yield cache_instance cache_instance._index.delete(True) # Clean up index