diff --git a/docs/api/query.rst b/docs/api/query.rst index a06d8bad..8086f5ca 100644 --- a/docs/api/query.rst +++ b/docs/api/query.rst @@ -3,7 +3,9 @@ Query ***** -.. _query_api: +Query classes in RedisVL provide a structured way to define simple or complex +queries for different use cases. Each query class wraps the ``redis-py`` Query module +https://github.com/redis/redis-py/blob/master/redis/commands/search/query.py with extended functionality for ease-of-use. VectorQuery @@ -15,19 +17,22 @@ VectorQuery .. autoclass:: VectorQuery :members: :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize -RangeQuery -========== +VectorRangeQuery +================ .. currentmodule:: redisvl.query -.. autoclass:: RangeQuery +.. autoclass:: VectorRangeQuery :members: :inherited-members: - + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize FilterQuery =========== @@ -39,6 +44,8 @@ FilterQuery .. autoclass:: FilterQuery :members: :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize @@ -51,3 +58,5 @@ CountQuery .. autoclass:: CountQuery :members: :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index e47afc1f..d09be7db 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -94,15 +94,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m08:48:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m08:48:47\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m14:06:19\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" ] } ], @@ -113,10 +113,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "# load data to redis\n", "keys = index.load(data)" ] }, @@ -141,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -175,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -201,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -227,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -264,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -299,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -326,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -352,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -387,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -415,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -441,7 +442,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -467,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -493,7 +494,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -519,7 +520,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -554,7 +555,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -582,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -608,7 +609,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -645,7 +646,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -688,7 +689,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -732,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -747,7 +748,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -772,7 +773,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -797,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -822,7 +823,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -856,7 +857,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -898,7 +899,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -932,7 +933,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -973,7 +974,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -1004,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -1032,24 +1033,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Other Redis Queries\n", - "\n", - "Sometimes there may be a case where RedisVL does not cover the explicit functionality required by the query either because of new releases that haven't been implemented in the client, or because of a very specific use case. In these cases, it is possible to use the ``SearchIndex.search`` method to execute query with a redis-py ``Query`` object or through a raw redis string.\n", - "\n", - "For example\n", - "\n", - "### Redis-Py" + "## Advanced Query Modifiers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See all modifier options available on the query API docs: https://www.redisvl.com/api/query.html" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
0.26666665077294nancyhighdoctor-122.4194,37.7749
0.65330135822335joemediumdentist-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
0.21788203716315taimurlowCEO-122.0839,37.3861
014derricklowdoctor-122.4194,37.7749
0.15880894660912timhighdermatologist-122.0839,37.3861
" + "
vector_distanceageusercredit_scorejoboffice_location
0.109129190445100tylerhighengineer-122.0839,37.3861
018johnhighengineer-122.4194,37.7749
" ], "text/plain": [ "" @@ -1060,14 +1062,16 @@ } ], "source": [ - "# Manipulate the Redis-py Query object\n", - "redis_py_query = v.query\n", - "\n", - "# choose to sort by age instead of vector distance\n", - "redis_py_query.sort_by(\"age\", asc=False)\n", + "# Sort by a different field and change dialect\n", + "v = VectorQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " num_results=5,\n", + " filter_expression=is_engineer\n", + ").sort_by(\"age\", asc=False).dialect(3)\n", "\n", - "# run the query with the ``SearchIndex.search`` method\n", - "result = index.search(redis_py_query, v.params)\n", + "result = index.query(v)\n", "result_print(result)" ] }, @@ -1077,76 +1081,64 @@ "source": [ "### Raw Redis Query String\n", "\n", - "So one case might be where you simply want to have a search that only filters on a tag field and don't need other functionality. Conversely, you may need to have a query that is more complex than what is currently supported by RedisVL. In these cases, you can use the ``SearchIndex.search`` again with just a raw redis query string." + "Sometimes it's helpful to convert these classes into their raw Redis query strings." ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'@credit_score:{high}'" + "'@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5'" ] }, - "execution_count": 36, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "t = Tag(\"credit_score\") == \"high\"\n", - "\n", - "str(t)" + "# check out the complex query from above\n", + "str(v)" ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'id': 'user_queries_docs:0e511391dcf346639669bdba70a189c0', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:d204e8e5df90467dbff5b2fb6f800a78', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:7cf3d6b1a4044966b4f0c5d3725a5e03', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:f6581edaaeaf432a85c1d1df8fdf5edc', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" - ] + "data": { + "text/plain": [ + "'@credit_score:{high}'" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "results = index.search(str(t))\n", - "for r in results.docs:\n", - " print(r.__dict__)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Inspecting Queries\n", - "\n", - "In this example, we will show how to inspect the query that is generated by RedisVL. This can be useful for debugging purposes or for understanding how the query is being executed.\n", + "t = Tag(\"credit_score\") == \"high\"\n", "\n", - "Let's again take the example of a query that combines a numeric filter with a tag filter. This will search for users that are between the ages of between 18 and 100, have a high credit score, and sort by closest vector distance to the query vector." + "str(t)" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10'" + "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])'" ] }, - "execution_count": 38, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -1158,15 +1150,42 @@ "\n", "combined = t & low & high\n", "\n", - "v.set_filter(combined)\n", - "\n", - "# Using the str() method, you can see what Redis Query this will emit.\n", - "str(v)" + "str(combined)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The RedisVL `SearchIndex` class exposes a `search()` method which is a simple wrapper around the `FT.SEARCH` API.\n", + "Provide any valid Redis query string." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': 'user_queries_docs:43dc726b8a9541a6ab40ddedc8e48657', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:93fdc65248a64fd390ed77aa3c248c23', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:f1d1f69e5e6c41cb9b7ae70ed8f75da5', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:5dc68e47ef6d4a0f885c67368f0710b7', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + ] + } + ], + "source": [ + "results = index.search(str(t))\n", + "for r in results.docs:\n", + " print(r.__dict__)" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -1191,7 +1210,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.11.9" }, "orig_nbformat": 4, "vscode": { diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 17856196..e7e0b290 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -296,14 +296,9 @@ def check( # overrides distance_threshold = distance_threshold or self._distance_threshold - return_fields = return_fields or self.return_fields vector = vector or self._vectorize_prompt(prompt) - self._check_vector_dims(vector) - if not isinstance(return_fields, list): - raise TypeError("return_fields must be a list of field names") - query = RangeQuery( vector=vector, vector_field_name=self.vector_field_name, @@ -320,15 +315,18 @@ def check( cache_search_results = self._index.query(query) for cache_search_result in cache_search_results: - key = cache_search_result["id"] - self._refresh_ttl(key) + redis_key = cache_search_result.pop("id") + self._refresh_ttl(redis_key) - # Create cache hit + # Create and process cache hit cache_hit = CacheHit(**cache_search_result) - cache_hit_dict = { - k: v for k, v in cache_hit.to_dict().items() if k in return_fields - } - cache_hit_dict["key"] = key + cache_hit_dict = cache_hit.to_dict() + # Filter down to only selected return fields if needed + if isinstance(return_fields, list) and len(return_fields) > 0: + cache_hit_dict = { + k: v for k, v in cache_hit_dict.items() if k in return_fields + } + cache_hit_dict[self.redis_key_field_name] = redis_key cache_hits.append(cache_hit_dict) return cache_hits diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 773f3fc5..f5f4c37b 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -120,12 +120,8 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: filter_expression=self._default_session_filter, return_fields=return_fields, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=True) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=True) + messages = self._index.query(query) return self._format_context(messages, as_text=False) @@ -255,12 +251,8 @@ def get_recent( return_fields=return_fields, num_results=top_k, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=False) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=False) + messages = self._index.query(query) if raw: return messages[::-1] diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 37d18a3e..9ecfbb5d 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -103,12 +103,8 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: filter_expression=self._default_session_filter, return_fields=return_fields, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=True) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=True) + messages = self._index.query(query) return self._format_context(messages, as_text=False) @@ -160,12 +156,8 @@ def get_recent( return_fields=return_fields, num_results=top_k, ) - - sorted_query = query.query - sorted_query.sort_by(self.timestamp_field_name, asc=False) - messages = [ - doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs - ] + query.sort_by(self.timestamp_field_name, asc=False) + messages = self._index.query(query) if raw: return messages[::-1] diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f5e6b4a6..5a4fd2df 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -65,7 +65,7 @@ def process_results( unpack_json = ( (storage_type == StorageType.JSON) and isinstance(query, FilterQuery) - and not query._return_fields + and not query._return_fields # type: ignore ) # Process records @@ -710,7 +710,7 @@ def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator: offset = 0 while True: - query.set_paging(offset, page_size) + query.paging(offset, page_size) results = self._query(query) if not results: break @@ -1194,7 +1194,7 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato first = 0 while True: - query.set_paging(first, page_size) + query.paging(first, page_size) results = await self._query(query) if not results: break diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index 68182e0f..8246794f 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -4,6 +4,14 @@ FilterQuery, RangeQuery, VectorQuery, + VectorRangeQuery, ) -__all__ = ["BaseQuery", "VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"] +__all__ = [ + "BaseQuery", + "VectorQuery", + "FilterQuery", + "RangeQuery", + "VectorRangeQuery", + "CountQuery", +] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index a1b3832b..9ba05481 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,237 +1,183 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from redis.commands.search.query import Query +from redis.commands.search.query import Query as RedisQuery from redisvl.query.filter import FilterExpression from redisvl.redis.utils import array_to_buffer -class BaseQuery: - def __init__( - self, - return_fields: Optional[List[str]] = None, - num_results: int = 10, - dialect: int = 2, - sort_by: Optional[str] = None, - in_order: bool = False, - ): - """Base query class used to subclass many query types.""" - self._return_fields = return_fields if return_fields is not None else [] - self._num_results = num_results - self._dialect = dialect - self._first = 0 - self._limit = num_results - self._sort_by = sort_by - self._in_order = in_order +class BaseQuery(RedisQuery): + """Base query class used to subclass many query types.""" + + _params: Dict[str, Any] = {} + _filter_expression: FilterExpression = FilterExpression("*") + + def __init__(self, query_string: str = "*"): + """ + Initialize the BaseQuery class. + + Args: + query_string (str, optional): The query string to use. Defaults to '*'. + """ + super().__init__(query_string) def __str__(self) -> str: - return " ".join([str(x) for x in self.query.get_args()]) + """Return the string representation of the query.""" + return " ".join([str(x) for x in self.get_args()]) + + def _build_query_string(self) -> str: + """Build the full Redis query string.""" + raise NotImplementedError("Must be implemented by subclasses") def set_filter(self, filter_expression: Optional[FilterExpression] = None): """Set the filter expression for the query. Args: - filter_expression (Optional[FilterExpression], optional): The filter - to apply to the query. + filter_expression (Optional[FilterExpression], optional): The filter to apply to the query. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression """ if filter_expression is None: # Default filter to match everything - self._filter = FilterExpression("*") + self._filter_expression = FilterExpression("*") elif isinstance(filter_expression, FilterExpression): - self._filter = filter_expression + self._filter_expression = filter_expression else: raise TypeError( "filter_expression must be of type FilterExpression or None" ) - def get_filter(self) -> FilterExpression: - """Get the filter expression for the query. + # Reset the query string + self._query_string = self._build_query_string() - Returns: - FilterExpression: The filter for the query. - """ - return self._filter - - def set_paging(self, first: int, limit: int): - """Set the paging parameters for the query to limit the number of - results. - - Args: - first (int): The zero-indexed offset for which to fetch query results - limit (int): The max number of results to include including the offset - - Raises: - TypeError: If first or limit are NOT integers. - """ - if not isinstance(first, int) or not isinstance(limit, int): - raise TypeError("Paging params must both be integers") - - self._first = first - self._limit = limit + @property + def filter(self) -> FilterExpression: + """The filter expression for the query.""" + return self._filter_expression @property - def query(self) -> Query: - raise NotImplementedError + def query(self) -> "BaseQuery": + """Return self as the query object.""" + return self @property def params(self) -> Dict[str, Any]: - return {} + """Return the query parameters.""" + return self._params -class CountQuery(BaseQuery): +class FilterQuery(BaseQuery): def __init__( self, - filter_expression: FilterExpression, + filter_expression: Optional[FilterExpression] = None, + return_fields: Optional[List[str]] = None, + num_results: int = 10, dialect: int = 2, + sort_by: Optional[str] = None, + in_order: bool = False, params: Optional[Dict[str, Any]] = None, ): - """A query for a simple count operation provided some filter expression. + """A query for running a filtered search with a filter expression. Args: - filter_expression (FilterExpression): The filter expression to query for. + filter_expression (Optional[FilterExpression]): The optional filter + expression to query with. Defaults to '*'. + return_fields (Optional[List[str]], optional): The fields to return. + num_results (Optional[int], optional): The number of results to return. Defaults to 10. + dialect (int, optional): The query dialect. Defaults to 2. + sort_by (Optional[str], optional): The field to order the results by. Defaults to None. + in_order (bool, optional): Requires the terms in the field to have the same order as the terms in the query filter. Defaults to False. params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression + """ + if filter_expression: + self._filter_expression = filter_expression + if params: + self._params = params - .. code-block:: python + self._num_results = num_results - from redisvl.query import CountQuery - from redisvl.query.filter import Tag + # Initialize the base query with the full query string constructed from the filter expression + query_string = self._build_query_string() + super().__init__(query_string) - t = Tag("brand") == "Nike" - query = CountQuery(filter_expression=t) + # Handle query settings + if return_fields: + self.return_fields(*return_fields) + self.paging(0, self._num_results).dialect(dialect) - count = index.query(query) - """ - super().__init__(num_results=0, dialect=dialect) - self.set_filter(filter_expression) - self._params = params or {} + if sort_by: + self.sort_by(sort_by) - @property - def query(self) -> Query: - """The loaded Redis-Py query. + if in_order: + self.in_order() - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - base_query = str(self._filter) - query = Query(base_query).no_content().paging(0, 0).dialect(self._dialect) - return query + def _build_query_string(self) -> str: + """Build the full query string based on the filter and other components.""" + # Example logic to build the full query string from filter and other parts + # This can be customized in child classes for more complex queries + return str(self._filter_expression) - @property - def params(self) -> Dict[str, Any]: - """The parameters for the query. - Returns: - Dict[str, Any]: The parameters for the query. - """ - return self._params - - -class FilterQuery(BaseQuery): +class CountQuery(BaseQuery): def __init__( self, - filter_expression: FilterExpression, - return_fields: Optional[List[str]] = None, - num_results: int = 10, + filter_expression: Optional[FilterExpression] = None, dialect: int = 2, - sort_by: Optional[str] = None, - in_order: bool = False, params: Optional[Dict[str, Any]] = None, ): - """A query for a running a filtered search with a filter expression. + """A query for a simple count operation provided some filter expression. Args: - filter_expression (FilterExpression): The filter expression to - query for. - return_fields (Optional[List[str]], optional): The fields to return. - num_results (Optional[int], optional): The number of results to - return. Defaults to 10. - sort_by (Optional[str]): The field to order the results by. Defaults - to None. Results will be ordered by vector distance. - in_order (bool): Requires the terms in the field to have - the same order as the terms in the query filter, regardless of - the offsets between them. Defaults to False. - params (Optional[Dict[str, Any]], optional): The parameters for the - query. Defaults to None. + filter_expression (Optional[FilterExpression]): The filter expression to query with. Defaults to None. + params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression .. code-block:: python - - from redisvl.query import FilterQuery + from redisvl.query import CountQuery from redisvl.query.filter import Tag t = Tag("brand") == "Nike" - q = FilterQuery(return_fields=["brand", "price"], filter_expression=t) + query = CountQuery(filter_expression=t) + count = index.query(query) """ - super().__init__(return_fields, num_results, dialect, sort_by, in_order) - self.set_filter(filter_expression) - self._params = params or {} + if filter_expression: + self._filter_expression = filter_expression + if params: + self._params = params - @property - def query(self) -> Query: - """Return a Redis-Py Query object representing the query. - - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - base_query = str(self._filter) - query = ( - Query(base_query) - .return_fields(*self._return_fields) - .paging(self._first, self._limit) - .dialect(self._dialect) - ) - if self._sort_by: - query = query.sort_by(self._sort_by) + # Initialize the base query with the full query string constructed from the filter expression + query_string = self._build_query_string() + super().__init__(query_string) - if self._in_order: - query = query.in_order() + # Query specific modifications + self.no_content().paging(0, 0).dialect(dialect) - return query + def _build_query_string(self) -> str: + """Build the full query string based on the filter and other components.""" + # Example logic to build the full query string from filter and other parts + # This can be customized in child classes for more complex queries + return str(self._filter_expression) -class BaseVectorQuery(BaseQuery): - DTYPES = { +class BaseVectorQuery: + DTYPES: Dict[str, Any] = { "float32": np.float32, "float64": np.float64, } - DISTANCE_ID = "vector_distance" - VECTOR_PARAM = "vector" + DISTANCE_ID: str = "vector_distance" + VECTOR_PARAM: str = "vector" - def __init__( - self, - vector: Union[List[float], bytes], - vector_field_name: str, - return_fields: Optional[List[str]] = None, - filter_expression: Optional[FilterExpression] = None, - dtype: str = "float32", - num_results: int = 10, - return_score: bool = True, - dialect: int = 2, - sort_by: Optional[str] = None, - in_order: bool = False, - ): - super().__init__(return_fields, num_results, dialect, sort_by, in_order) - self.set_filter(filter_expression) - self._vector = vector - self._field = vector_field_name - self._dtype = dtype.lower() - if return_score: - self._return_fields.append(self.DISTANCE_ID) - - -class VectorQuery(BaseVectorQuery): +class VectorQuery(BaseVectorQuery, BaseQuery): def __init__( self, vector: Union[List[float], bytes], @@ -276,42 +222,37 @@ def __init__( Note: Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search """ - super().__init__( - vector, - vector_field_name, - return_fields, - filter_expression, - dtype, - num_results, - return_score, - dialect, - sort_by, - in_order, - ) + if filter_expression: + self._filter_expression = filter_expression - @property - def query(self) -> Query: - """Return a Redis-Py Query object representing the query. + self._vector = vector + self._vector_field_name = vector_field_name + self._dtype = dtype + self._num_results = num_results - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - base_query = f"{str(self._filter)}=>[KNN {self._num_results} @{self._field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" - query = ( - Query(base_query) - .return_fields(*self._return_fields) - .paging(self._first, self._limit) - .dialect(self._dialect) - ) - if self._sort_by: - query = query.sort_by(self._sort_by) + query_string = self._build_query_string() + super().__init__(query_string) + + # Handle query modifiers + if return_fields: + self.return_fields(*return_fields) + + self.paging(0, self._num_results).dialect(dialect) + + if return_score: + self.return_fields(self.DISTANCE_ID) + + if sort_by: + self.sort_by(sort_by) else: - query = query.sort_by(self.DISTANCE_ID) + self.sort_by(self.DISTANCE_ID) - if self._in_order: - query = query.in_order() + if in_order: + self.in_order() - return query + def _build_query_string(self) -> str: + """Build the full query string for vector search with optional filtering.""" + return f"{str(self._filter_expression)}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" @property def params(self) -> Dict[str, Any]: @@ -321,15 +262,15 @@ def params(self) -> Dict[str, Any]: Dict[str, Any]: The parameters for the query. """ if isinstance(self._vector, bytes): - vector_param = self._vector + vector = self._vector else: - vector_param = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) + vector = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) - return {self.VECTOR_PARAM: vector_param} + return {self.VECTOR_PARAM: vector} -class RangeQuery(BaseVectorQuery): - DISTANCE_THRESHOLD_PARAM = "distance_threshold" +class VectorRangeQuery(BaseVectorQuery, BaseQuery): + DISTANCE_THRESHOLD_PARAM: str = "distance_threshold" def __init__( self, @@ -380,22 +321,48 @@ def __init__( Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query """ - super().__init__( - vector, - vector_field_name, - return_fields, - filter_expression, - dtype, - num_results, - return_score, - dialect, - sort_by, - in_order, - ) + if filter_expression: + self._filter_expression = filter_expression + + self._vector = vector + self._vector_field_name = vector_field_name + self._dtype = dtype + self._num_results = num_results self.set_distance_threshold(distance_threshold) + query_string = self._build_query_string() + super().__init__(query_string) + + # Handle query modifiers + if return_fields: + self.return_fields(*return_fields) + + self.paging(0, self._num_results).dialect(dialect) + + if return_score: + self.return_fields(self.DISTANCE_ID) + + if sort_by: + self.sort_by(sort_by) + else: + self.sort_by(self.DISTANCE_ID) + + if in_order: + self.in_order() + + def _build_query_string(self) -> str: + """Build the full query string for vector range queries with optional filtering""" + base_query = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" + _filter = str(self._filter_expression) + if _filter != "*": + return ( + f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})" + ) + else: + return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" + def set_distance_threshold(self, distance_threshold: float): - """Set the distance treshold for the query. + """Set the distance threshold for the query. Args: distance_threshold (float): vector distance @@ -413,40 +380,6 @@ def distance_threshold(self) -> float: """ return self._distance_threshold - @property - def query(self) -> Query: - """Return a Redis-Py Query object representing the query. - - Returns: - redis.commands.search.query.Query: The Redis-Py query object. - """ - base_query = f"@{self._field}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" - - _filter = str(self._filter) - - if _filter != "*": - base_query = ( - f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {_filter})" - ) - else: - base_query = f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" - - query = ( - Query(base_query) - .return_fields(*self._return_fields) - .paging(self._first, self._limit) - .dialect(self._dialect) - ) - if self._sort_by: - query = query.sort_by(self._sort_by) - else: - query = query.sort_by(self.DISTANCE_ID) - - if self._in_order: - query = query.in_order() - - return query - @property def params(self) -> Dict[str, Any]: """Return the parameters for the query. @@ -463,3 +396,8 @@ def params(self) -> Dict[str, Any]: self.VECTOR_PARAM: vector_param, self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold, } + + +class RangeQuery(VectorRangeQuery): + # keep for backwards compatibility + pass diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 2263b745..09c9327b 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -19,7 +19,9 @@ def vectorizer(): @pytest.fixture def cache(vectorizer, redis_url): cache_instance = SemanticCache( - vectorizer=vectorizer, distance_threshold=0.2, redis_url=redis_url + vectorizer=vectorizer, + distance_threshold=0.2, + redis_url=redis_url, ) yield cache_instance cache_instance._index.delete(True) # Clean up index @@ -266,9 +268,6 @@ def test_check_invalid_input(cache): with pytest.raises(ValueError): cache.check() - with pytest.raises(TypeError): - cache.check(prompt="test", return_fields="bad value") - def test_bad_connection_info(vectorizer): with pytest.raises(ConnectionError): diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index e0fd4f17..17426868 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -1,3 +1,4 @@ +import pytest from redis.commands.search.query import Query from redis.commands.search.result import Result @@ -18,6 +19,7 @@ def test_count_query(): count_query = CountQuery(filter_expression) # Check properties + assert isinstance(count_query, Query) assert isinstance(count_query.query, Query) assert isinstance(count_query.params, dict) assert count_query.params == {} @@ -25,7 +27,7 @@ def test_count_query(): # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" count_query.set_filter(new_filter_expression) - assert count_query.get_filter() == new_filter_expression + assert count_query.filter == new_filter_expression fake_result = Result([2], "") assert process_results(fake_result, count_query, "json") == 2 @@ -40,30 +42,31 @@ def test_filter_query(): # Check properties assert filter_query._return_fields == return_fields assert filter_query._num_results == 10 - assert filter_query.get_filter() == filter_expression + assert filter_query.filter == filter_expression + assert isinstance(filter_query, Query) assert isinstance(filter_query.query, Query) assert isinstance(filter_query.params, dict) assert filter_query.params == {} assert filter_query._dialect == 2 - assert filter_query._sort_by == None + assert filter_query._sortby == None assert filter_query._in_order == False # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" filter_query.set_filter(new_filter_expression) - assert filter_query.get_filter() == new_filter_expression + assert filter_query.filter == new_filter_expression - # Test set_paging functionality - filter_query.set_paging(5, 7) - assert filter_query._first == 5 - assert filter_query._limit == 7 + # Test paging functionality + filter_query.paging(5, 7) + assert filter_query._offset == 5 + assert filter_query._num == 7 assert filter_query._num_results == 10 # Test sort_by functionality filter_query = FilterQuery( filter_expression, return_fields, num_results=10, sort_by="price" ) - assert filter_query._sort_by == "price" + assert filter_query._sortby is not None # Test in_order functionality filter_query = FilterQuery( @@ -80,25 +83,26 @@ def test_vector_query(): # Check properties assert vector_query._vector == sample_vector - assert vector_query._field == "vector_field" + assert vector_query._vector_field_name == "vector_field" assert vector_query._num_results == 10 - assert "field1" in vector_query._return_fields + assert vector_query._return_fields == ["field1", "field2", "vector_distance"] + assert isinstance(vector_query, Query) assert isinstance(vector_query.query, Query) assert isinstance(vector_query.params, dict) assert vector_query.params != {} assert vector_query._dialect == 3 - assert vector_query._sort_by == None + assert vector_query._sortby.args[0] == VectorQuery.DISTANCE_ID assert vector_query._in_order == False # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" vector_query.set_filter(new_filter_expression) - assert vector_query.get_filter() == new_filter_expression + assert vector_query.filter == new_filter_expression - # Test set_paging functionality - vector_query.set_paging(5, 7) - assert vector_query._first == 5 - assert vector_query._limit == 7 + # Test paging functionality + vector_query.paging(5, 7) + assert vector_query._offset == 5 + assert vector_query._num == 7 assert vector_query._num_results == 10 # Test sort_by functionality @@ -110,7 +114,18 @@ def test_vector_query(): num_results=10, sort_by="field2", ) - assert vector_query._sort_by == "field2" + assert vector_query._sortby.args[0] == "field2" + + # Test in_order functionality + vector_query = VectorQuery( + sample_vector, + "vector_field", + ["field1", "field2"], + dialect=3, + num_results=10, + in_order=True, + ) + assert vector_query._in_order def test_range_query(): @@ -124,25 +139,29 @@ def test_range_query(): # Check properties assert range_query._vector == sample_vector - assert range_query._field == "vector_field" + assert range_query._vector_field_name == "vector_field" assert range_query._num_results == 10 assert range_query.distance_threshold == 0.2 assert "field1" in range_query._return_fields + assert isinstance(range_query, Query) assert isinstance(range_query.query, Query) assert isinstance(range_query.params, dict) assert range_query.params != {} - assert range_query._sort_by == None - assert range_query._sort_by == None + assert range_query._sortby.args[0] == RangeQuery.DISTANCE_ID + + # Test set_distance_threshold functionality + range_query.set_distance_threshold(0.1) + assert range_query.distance_threshold == 0.1 # Test set_filter functionality new_filter_expression = Tag("category") == "Outdoor" range_query.set_filter(new_filter_expression) - assert range_query.get_filter() == new_filter_expression + assert range_query.filter == new_filter_expression - # Test set_paging functionality - range_query.set_paging(5, 7) - assert range_query._first == 5 - assert range_query._limit == 7 + # Test paging functionality + range_query.paging(5, 7) + assert range_query._offset == 5 + assert range_query._num == 7 assert range_query._num_results == 10 # Test sort_by functionality @@ -154,4 +173,82 @@ def test_range_query(): num_results=10, sort_by="field1", ) - assert range_query._sort_by == "field1" + assert range_query._sortby.args[0] == "field1" + + # Test in_order functionality + range_query = RangeQuery( + sample_vector, + "vector_field", + ["field1"], + filter_expression, + num_results=10, + in_order=True, + ) + assert range_query._in_order + + +@pytest.mark.parametrize( + "query", + [ + CountQuery(), + FilterQuery(), + VectorQuery(vector=[1, 2, 3], vector_field_name="vector"), + RangeQuery(vector=[1, 2, 3], vector_field_name="vector"), + ], +) +def test_query_modifiers(query): + query.paging(3, 5) + assert query._offset == 3 + assert query._num == 5 + + query.dialect(4) + assert query._dialect == 4 + + query.in_order() + assert query._in_order + + query.sort_by("time") + assert query._sortby.args[0] == "time" + + query.scorer("BM25") + assert query._scorer == "BM25" + + query.timeout(20) + assert query._timeout == 20 + + query.slop(10) + assert query._slop == 10 + + query.verbatim() + assert query._verbatim + + query.no_content() + assert query._no_content + + query.no_stopwords() + assert query._no_stopwords + + query.with_scores() + assert query._with_scores + + query.limit_fields("test") + assert query._fields == ("test",) + + f = Tag("test") == "foo" + query.set_filter(f) + assert query._filter_expression == f + + # double check all other states + assert query._offset == 3 + assert query._num == 5 + assert query._dialect == 4 + assert query._in_order + assert query._sortby.args[0] == "time" + assert query._scorer == "BM25" + assert query._timeout == 20 + assert query._slop == 10 + assert query._verbatim + assert query._no_content + assert query._no_stopwords + assert query._with_scores + assert query._fields == ("test",)