From b0d8e917329e920ece7084578d19732e99db05a7 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 17 Sep 2024 10:21:42 -0700 Subject: [PATCH 01/25] adds ml_dtypes as dependency for Bfloat16 --- docs/api/schema.rst | 2 +- poetry.lock | 45 ++++++++++++++++++-- pyproject.toml | 1 + redisvl/extensions/llmcache/schema.py | 4 +- redisvl/extensions/router/schema.py | 4 +- redisvl/extensions/session_manager/schema.py | 4 +- redisvl/query/query.py | 3 ++ redisvl/schema/fields.py | 2 + 8 files changed, 54 insertions(+), 11 deletions(-) diff --git a/docs/api/schema.rst b/docs/api/schema.rst index ebe4ca8a..36245ba4 100644 --- a/docs/api/schema.rst +++ b/docs/api/schema.rst @@ -88,7 +88,7 @@ Each field type supports specific attributes that customize its behavior. Below - `dims`: Dimensionality of the vector. - `algorithm`: Indexing algorithm (`flat` or `hnsw`). -- `datatype`: Float datatype of the vector (`float32` or `float64`). +- `datatype`: Float datatype of the vector (`bfloat16`, `float16`, `float32`, `float64`). - `distance_metric`: Metric for measuring query relevance (`COSINE`, `L2`, `IP`). **HNSW Vector Field Specific Attributes**: diff --git a/poetry.lock b/poetry.lock index 24c3048f..bfeb29d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1003,12 +1003,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -2351,6 +2351,43 @@ files = [ intel-openmp = "==2021.*" tbb = "==2021.*" +[[package]] +name = "ml-dtypes" +version = "0.4.1" +description = "" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ml_dtypes-0.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1fe8b5b5e70cd67211db94b05cfd58dace592f24489b038dc6f9fe347d2e07d5"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c09a6d11d8475c2a9fd2bc0695628aec105f97cab3b3a3fb7c9660348ff7d24"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5e8f75fa371020dd30f9196e7d73babae2abd51cf59bdd56cb4f8de7e13354"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:15fdd922fea57e493844e5abb930b9c0bd0af217d9edd3724479fc3d7ce70e3f"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2d55b588116a7085d6e074cf0cdb1d6fa3875c059dddc4d2c94a4cc81c23e975"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e138a9b7a48079c900ea969341a5754019a1ad17ae27ee330f7ebf43f23877f9"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74c6cfb5cf78535b103fde9ea3ded8e9f16f75bc07789054edc7776abfb3d752"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:274cc7193dd73b35fb26bef6c5d40ae3eb258359ee71cd82f6e96a8c948bdaa6"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:827d3ca2097085cf0355f8fdf092b888890bb1b1455f52801a2d7756f056f54b"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:772426b08a6172a891274d581ce58ea2789cc8abc1c002a27223f314aaf894e7"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:df0fb650d5c582a9e72bb5bd96cfebb2cdb889d89daff621c8fbc60295eba66c"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e35e486e97aee577d0890bc3bd9e9f9eece50c08c163304008587ec8cfe7575b"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:560be16dc1e3bdf7c087eb727e2cf9c0e6a3d87e9f415079d2491cc419b3ebf5"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad0b757d445a20df39035c4cdeed457ec8b60d236020d2560dbc25887533cf50"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef0d7e3fece227b49b544fa69e50e607ac20948f0043e9f76b44f35f229ea450"}, + {file = "ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">1.20", markers = "python_version < \"3.10\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "mpmath" version = "1.3.0" @@ -3426,9 +3463,9 @@ files = [ astroid = ">=3.1.0,<=3.2.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=0.2", markers = "python_version < \"3.11\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -5523,4 +5560,4 @@ sentence-transformers = ["sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "be9b5df2ff3600823749e4d0bfffe148c6bb04f88fa287a3dfae712ade9fd06e" +content-hash = "4dbfe0e66ba3b90c5cb8746034ec5e870e07eb8207c5ba95ac700939c91ac89d" diff --git a/pyproject.toml b/pyproject.toml index 21af30ab..62902e27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ sentence-transformers = { version = ">=2.2.2", optional = true } google-cloud-aiplatform = { version = ">=1.26", optional = true } cohere = { version = ">=4.44", optional = true } mistralai = { version = ">=0.2.0", optional = true } +ml-dtypes = "^0.4.0" [tool.poetry.extras] openai = ["openai"] diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 515b1421..12b066dc 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -105,7 +105,7 @@ def to_dict(self) -> Dict: class SemanticCacheIndexSchema(IndexSchema): @classmethod - def from_params(cls, name: str, prefix: str, vector_dims: int): + def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str = "float32"): return cls( index={"name": name, "prefix": prefix}, # type: ignore @@ -119,7 +119,7 @@ def from_params(cls, name: str, prefix: str, vector_dims: int): "type": "vector", "attrs": { "dims": vector_dims, - "datatype": "float32", + "datatype": dtype, "distance_metric": "cosine", "algorithm": "flat", }, diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 11b88dc6..800d0fe1 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -88,7 +88,7 @@ class SemanticRouterIndexSchema(IndexSchema): """Customized index schema for SemanticRouter.""" @classmethod - def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": + def from_params(cls, name: str, vector_dims: int, dtype: str= "float32") -> "SemanticRouterIndexSchema": """Create an index schema based on router name and vector dimensions. Args: @@ -110,7 +110,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema" "algorithm": "flat", "dims": vector_dims, "distance_metric": "cosine", - "datatype": "float32", + "datatype": dtype, }, }, ], diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 0e35edd2..31653701 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -67,7 +67,7 @@ def from_params(cls, name: str, prefix: str): class SemanticSessionIndexSchema(IndexSchema): @classmethod - def from_params(cls, name: str, prefix: str, vectorizer_dims: int): + def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str = "float32"): return cls( index={"name": name, "prefix": prefix}, # type: ignore @@ -82,7 +82,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int): "type": "vector", "attrs": { "dims": vectorizer_dims, - "datatype": "float32", + "datatype": dtype, "distance_metric": "cosine", "algorithm": "flat", }, diff --git a/redisvl/query/query.py b/redisvl/query/query.py index a1b3832b..ace7f03e 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional, Union import numpy as np +from ml_dtypes import bfloat16 from redis.commands.search.query import Query from redisvl.query.filter import FilterExpression @@ -202,6 +203,8 @@ def query(self) -> Query: class BaseVectorQuery(BaseQuery): DTYPES = { + "bfloat16": bfloat16, + "float16": np.float16, "float32": np.float32, "float64": np.float64, } diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py index 7dd85bea..132e785f 100644 --- a/redisvl/schema/fields.py +++ b/redisvl/schema/fields.py @@ -26,6 +26,8 @@ class VectorDistanceMetric(str, Enum): class VectorDataType(str, Enum): + BFLOAT16 = "BFLOAT16" + FLOAT16 = "FLOAT16" FLOAT32 = "FLOAT32" FLOAT64 = "FLOAT64" From 80cdd321195740dbd0fa71341ddbad1bffd2d4e7 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 18 Sep 2024 14:29:17 -0700 Subject: [PATCH 02/25] wip: tests pass with float32 default, except test_to_yaml_and_reload --- redisvl/extensions/llmcache/schema.py | 8 ++++--- redisvl/extensions/llmcache/semantic.py | 9 +++++-- redisvl/extensions/router/schema.py | 4 ++-- redisvl/extensions/router/semantic.py | 10 ++++++-- redisvl/extensions/session_manager/schema.py | 8 ++++--- .../session_manager/semantic_session.py | 4 +++- redisvl/index/index.py | 2 +- redisvl/index/storage.py | 17 +++++++++++-- redisvl/query/query.py | 14 +++-------- redisvl/redis/utils.py | 24 +++++++++++++++++-- redisvl/schema/schema.py | 4 +++- redisvl/utils/vectorize/base.py | 5 ++-- redisvl/utils/vectorize/text/azureopenai.py | 10 ++++---- redisvl/utils/vectorize/text/cohere.py | 4 ++-- redisvl/utils/vectorize/text/custom.py | 12 ++++++---- redisvl/utils/vectorize/text/huggingface.py | 4 ++-- redisvl/utils/vectorize/text/mistral.py | 10 ++++---- redisvl/utils/vectorize/text/openai.py | 10 ++++---- redisvl/utils/vectorize/text/vertexai.py | 4 ++-- tests/integration/test_flow.py | 5 +++- tests/integration/test_flow_async.py | 5 +++- tests/integration/test_query.py | 5 +++- tests/integration/test_session_manager.py | 3 +-- tests/unit/test_llmcache_schema.py | 2 +- tests/unit/test_session_schema.py | 2 +- tests/unit/test_utils.py | 24 ++++++++----------- 26 files changed, 134 insertions(+), 75 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 12b066dc..c4df4ee3 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -26,6 +26,8 @@ class CacheEntry(BaseModel): """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" + dtype: str = Field(default="float32") + """The data type for the prompt vector.""" @root_validator(pre=True) @classmethod @@ -43,7 +45,7 @@ def non_empty_metadata(cls, v): def to_dict(self) -> Dict: data = self.dict(exclude_none=True) - data["prompt_vector"] = array_to_buffer(self.prompt_vector) + data["prompt_vector"] = array_to_buffer(self.prompt_vector, self.dtype) if self.metadata is not None: data["metadata"] = serialize(self.metadata) if self.filters is not None: @@ -105,10 +107,10 @@ def to_dict(self) -> Dict: class SemanticCacheIndexSchema(IndexSchema): @classmethod - def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str = "float32"): + def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str): return cls( - index={"name": name, "prefix": prefix}, # type: ignore + index={"name": name, "prefix": prefix, "dtype": dtype.upper()}, # type: ignore fields=[ # type: ignore {"name": "prompt", "type": "text"}, {"name": "response", "type": "text"}, diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 17856196..7c16f358 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -92,7 +92,10 @@ def __init__( ] # Create semantic cache schema and index - schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) + dtype = kwargs.get("dtype", "float32") + schema = SemanticCacheIndexSchema.from_params( + name, prefix, vectorizer.dims, dtype + ) schema = self._modify_schema(schema, filterable_fields) self._index = SearchIndex(schema=schema) @@ -235,7 +238,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) + return self._vectorizer.embed(prompt, dtype=self.index.schema.index.dtype) def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it @@ -312,6 +315,7 @@ def check( num_results=num_results, return_score=True, filter_expression=filter_expression, + dtype=self.index.schema.index.dtype, ) cache_hits: List[Dict[Any, str]] = [] @@ -382,6 +386,7 @@ def store( prompt_vector=vector, metadata=metadata, filters=filters, + dtype=self.index.schema.index.dtype, ) # Load cache entry with TTL diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 800d0fe1..f171baac 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -88,7 +88,7 @@ class SemanticRouterIndexSchema(IndexSchema): """Customized index schema for SemanticRouter.""" @classmethod - def from_params(cls, name: str, vector_dims: int, dtype: str= "float32") -> "SemanticRouterIndexSchema": + def from_params(cls, name: str, vector_dims: int, dtype: str): """Create an index schema based on router name and vector dimensions. Args: @@ -99,7 +99,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str= "float32") -> "Sem SemanticRouterIndexSchema: The constructed index schema. """ return cls( - index=IndexInfo(name=name, prefix=name), + index={"name": name, "prefix": name, "dtype": dtype.upper()}, # type: ignore fields=[ # type: ignore {"name": "route_name", "type": "tag"}, {"name": "reference", "type": "text"}, diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index bab69578..ba310709 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -84,17 +84,23 @@ def __init__( vectorizer=vectorizer, routing_config=routing_config, ) - self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs) + dtype = kwargs.get("dtype", "float32") + self._initialize_index( + redis_client, redis_url, overwrite, dtype, **connection_kwargs + ) def _initialize_index( self, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", overwrite: bool = False, + dtype: str = "float32", **connection_kwargs, ): """Initialize the search index and handle Redis connection.""" - schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) + schema = SemanticRouterIndexSchema.from_params( + self.name, self.vectorizer.dims, dtype + ) self._index = SearchIndex(schema=schema) if redis_client: diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 31653701..80934582 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -24,6 +24,8 @@ class ChatMessage(BaseModel): """An optional identifier for a tool call associated with the message.""" vector_field: Optional[List[float]] = Field(default=None) """The vector representation of the message content.""" + dtype: Optional[str] = Field(default="float32") + """The data type for the prompt vector.""" class Config: arbitrary_types_allowed = True @@ -42,7 +44,7 @@ def to_dict(self) -> Dict: # handle optional fields if "vector_field" in data: - data["vector_field"] = array_to_buffer(data["vector_field"]) + data["vector_field"] = array_to_buffer(data["vector_field"], self.dtype) # type: ignore[arg-type] return data @@ -67,10 +69,10 @@ def from_params(cls, name: str, prefix: str): class SemanticSessionIndexSchema(IndexSchema): @classmethod - def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str = "float32"): + def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str): return cls( - index={"name": name, "prefix": prefix}, # type: ignore + index={"name": name, "prefix": prefix, "dtype": dtype.upper()}, # type: ignore fields=[ # type: ignore {"name": "role", "type": "tag"}, {"name": "content", "type": "text"}, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 773f3fc5..6b83f764 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -51,6 +51,7 @@ def __init__( redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. connection_kwargs (Dict[str, Any]): The connection arguments for the redis client. Defaults to empty {}. + dtype (str): The data type for the prompt vector. Defaults to "float32". The proposed schema will support a single vector embedding constructed from either the prompt or response in a single string. @@ -66,8 +67,9 @@ def __init__( self.set_distance_threshold(distance_threshold) + dtype = kwargs.get("dtype", "float32") schema = SemanticSessionIndexSchema.from_params( - name, prefix, self._vectorizer.dims + name, prefix, self._vectorizer.dims, dtype ) self._index = SearchIndex(schema=schema) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f5e6b4a6..76deaa4c 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -351,7 +351,7 @@ def from_existing( # Validate modules installed_modules = RedisConnectionFactory.get_modules(redis_client) - validate_modules(installed_modules, [{"name": "search", "ver": 20810}]) + validate_modules(installed_modules, [{"name": "search", "ver": 21005}]) # Fetch index info and convert to schema index_info = cls._info(name, redis_client) diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index 12ef2052..9b9aa36c 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -2,6 +2,7 @@ import uuid from typing import Any, Callable, Dict, Iterable, List, Optional +from numpy import frombuffer from pydantic.v1 import BaseModel from redis import Redis from redis.asyncio import Redis as AsyncRedis @@ -393,18 +394,30 @@ class HashStorage(BaseStorage): """Hash data type for the index""" def _validate(self, obj: Dict[str, Any]): - """Validate that the given object is a dictionary, suitable for storage - as a Redis hash. + """Validate that the given object is a dictionary suitable for storage + as a Redis hash, and the vector byte string is of correct datatype. Args: obj (Dict[str, Any]): The object to validate. Raises: TypeError: If the object is not a dictionary. + ValueError: if the vector byte string is not of correct datatype. """ if not isinstance(obj, dict): raise TypeError("Object must be a dictionary.") + try: + byte_string = obj["prompt_vector"] + vector = frombuffer(byte_string, dtype=obj["dtype"].lower()) + except ValueError: + raise ValueError( + f"Could not convert byte string to vector of type {obj['dtype']}.", + "buffer size must be a multiple of element size", + ) + except KeyError: + pass # regular hash entry with no prompt_vector or dtype needed + @staticmethod def _set(client: Redis, key: str, obj: Dict[str, Any]): """Synchronously set a hash value in Redis for the given key. diff --git a/redisvl/query/query.py b/redisvl/query/query.py index ace7f03e..831f0df9 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional, Union -import numpy as np -from ml_dtypes import bfloat16 from redis.commands.search.query import Query from redisvl.query.filter import FilterExpression @@ -202,12 +200,6 @@ def query(self) -> Query: class BaseVectorQuery(BaseQuery): - DTYPES = { - "bfloat16": bfloat16, - "float16": np.float16, - "float32": np.float32, - "float64": np.float64, - } DISTANCE_ID = "vector_distance" VECTOR_PARAM = "vector" @@ -228,7 +220,7 @@ def __init__( self.set_filter(filter_expression) self._vector = vector self._field = vector_field_name - self._dtype = dtype.lower() + self._dtype = dtype if return_score: self._return_fields.append(self.DISTANCE_ID) @@ -326,7 +318,7 @@ def params(self) -> Dict[str, Any]: if isinstance(self._vector, bytes): vector_param = self._vector else: - vector_param = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) + vector_param = array_to_buffer(self._vector, dtype=self._dtype) return {self.VECTOR_PARAM: vector_param} @@ -460,7 +452,7 @@ def params(self) -> Dict[str, Any]: if isinstance(self._vector, bytes): vector_param = self._vector else: - vector_param = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) + vector_param = array_to_buffer(self._vector, dtype=self._dtype) return { self.VECTOR_PARAM: vector_param, diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index a421022b..3186fcfd 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -2,6 +2,14 @@ from typing import Any, Dict, List import numpy as np +from ml_dtypes import bfloat16 + +VectorDataTypes = { + "BFLOAT16": bfloat16, + "FLOAT16": np.float16, + "FLOAT32": np.float32, + "FLOAT64": np.float64, +} def make_dict(values: List[Any]) -> Dict[Any, Any]: @@ -30,13 +38,25 @@ def convert_bytes(data: Any) -> Any: return data -def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: +def array_to_buffer(array: List[float], dtype: str) -> bytes: """Convert a list of floats into a numpy byte string.""" + try: + dtype = VectorDataTypes[dtype.upper()] + except KeyError: + raise ValueError( + f"Invalid data type: {dtype}. Supported types are: {VectorDataTypes.keys()}" + ) return np.array(array).astype(dtype).tobytes() -def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: +def buffer_to_array(buffer: bytes, dtype: str) -> List[float]: """Convert bytes into into a list of floats.""" + try: + dtype = VectorDataTypes[dtype.upper()] + except KeyError: + raise ValueError( + f"Invalid data type: {dtype}. Supported types are: {VectorDataTypes.keys()}" + ) return np.frombuffer(buffer, dtype=dtype).tolist() diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index 7f3db845..64938651 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -7,7 +7,7 @@ from pydantic.v1 import BaseModel, Field, root_validator from redis.commands.search.field import Field as RedisField -from redisvl.schema.fields import BaseField, FieldFactory +from redisvl.schema.fields import BaseField, FieldFactory, VectorDataType from redisvl.utils.log import get_logger from redisvl.utils.utils import model_to_dict @@ -63,6 +63,8 @@ class IndexInfo(BaseModel): """The separator character used in designing Redis keys.""" storage_type: StorageType = StorageType.HASH """The storage type used in Redis (e.g., 'hash' or 'json').""" + dtype: VectorDataType = VectorDataType.FLOAT32 + """The data type for the vector field in the index.""" class IndexSchema(BaseModel): diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 3ea2dccd..cb6703a5 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -81,7 +81,8 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None): else: yield seq[pos : pos + size] - def _process_embedding(self, embedding: List[float], as_buffer: bool): + def _process_embedding(self, embedding: List[float], as_buffer: bool, **kwargs): + dtype = kwargs.get("dtype", "float32") if as_buffer: - return array_to_buffer(embedding) + return array_to_buffer(embedding, dtype) return embedding diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 734fef5b..3129c0b0 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -194,7 +194,8 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embeddings.create(input=batch, model=self.model) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -231,7 +232,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -274,7 +275,8 @@ async def aembed_many( input=batch, model=self.model ) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -311,7 +313,7 @@ async def aembed( if preprocess: text = preprocess(text) result = await self._aclient.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 47275d40..8270f09b 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -160,7 +160,7 @@ def embed( embedding = self._client.embed( texts=[text], model=self.model, input_type=input_type ).embeddings[0] - return self._process_embedding(embedding, as_buffer) + return self._process_embedding(embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -230,7 +230,7 @@ def embed_many( texts=batch, model=self.model, input_type=input_type ) embeddings += [ - self._process_embedding(embedding, as_buffer) + self._process_embedding(embedding, as_buffer, **kwargs) for embedding in response.embeddings ] return embeddings diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index bf0eec4a..8dc42c12 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -174,7 +174,7 @@ def embed( text = preprocess(text) else: result = self._embed_func(text, **kwargs) - return self._process_embedding(result, as_buffer) + return self._process_embedding(result, as_buffer, **kwargs) def embed_many( self, @@ -213,7 +213,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): results = self._embed_many_func(batch, **kwargs) - embeddings += [self._process_embedding(r, as_buffer) for r in results] + embeddings += [ + self._process_embedding(r, as_buffer, **kwargs) for r in results + ] return embeddings async def aembed( @@ -249,7 +251,7 @@ async def aembed( text = preprocess(text) else: result = await self._aembed_func(text, **kwargs) - return self._process_embedding(result, as_buffer) + return self._process_embedding(result, as_buffer, **kwargs) async def aembed_many( self, @@ -288,7 +290,9 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): results = await self._aembed_many_func(batch, **kwargs) - embeddings += [self._process_embedding(r, as_buffer) for r in results] + embeddings += [ + self._process_embedding(r, as_buffer, **kwargs) for r in results + ] return embeddings @property diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index d5e255c9..f867dfa9 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -100,7 +100,7 @@ def embed( if preprocess: text = preprocess(text) embedding = self._client.encode([text])[0] - return self._process_embedding(embedding.tolist(), as_buffer) + return self._process_embedding(embedding.tolist(), as_buffer, **kwargs) def embed_many( self, @@ -138,7 +138,7 @@ def embed_many( batch_embeddings = self._client.encode(batch) embeddings.extend( [ - self._process_embedding(embedding.tolist(), as_buffer) + self._process_embedding(embedding.tolist(), as_buffer, **kwargs) for embedding in batch_embeddings ] ) diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index 8776ef3d..7d4f00f5 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -144,7 +144,8 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embeddings(model=self.model, input=batch) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -181,7 +182,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.embeddings(model=self.model, input=[text]) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -222,7 +223,8 @@ async def aembed_many( for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings(model=self.model, input=batch) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -259,7 +261,7 @@ async def aembed( if preprocess: text = preprocess(text) result = await self._aclient.embeddings(model=self.model, input=[text]) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 5921bda8..ae5d19dc 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -148,7 +148,8 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embeddings.create(input=batch, model=self.model) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -185,7 +186,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -228,7 +229,8 @@ async def aembed_many( input=batch, model=self.model ) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -265,7 +267,7 @@ async def aembed( if preprocess: text = preprocess(text) result = await self._aclient.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index b7248003..c69b0531 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -155,7 +155,7 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.get_embeddings(batch) embeddings += [ - self._process_embedding(r.values, as_buffer) for r in response + self._process_embedding(r.values, as_buffer, **kwargs) for r in response ] return embeddings @@ -192,7 +192,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.get_embeddings([text]) - return self._process_embedding(result[0].values, as_buffer) + return self._process_embedding(result[0].values, as_buffer, **kwargs) async def aembed_many( self, diff --git a/tests/integration/test_flow.py b/tests/integration/test_flow.py index 538b02d3..b448a636 100644 --- a/tests/integration/test_flow.py +++ b/tests/integration/test_flow.py @@ -51,7 +51,10 @@ def test_simple(client, schema, sample_data): # Prepare and load the data based on storage type def hash_preprocess(item: dict) -> dict: - return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } if index.storage_type == StorageType.HASH: index.load(sample_data, preprocess=hash_preprocess, id_field="user") diff --git a/tests/integration/test_flow_async.py b/tests/integration/test_flow_async.py index 3557ded8..fbfa7d22 100644 --- a/tests/integration/test_flow_async.py +++ b/tests/integration/test_flow_async.py @@ -55,7 +55,10 @@ async def test_simple(async_client, schema, sample_data): # Prepare and load the data based on storage type async def hash_preprocess(item: dict) -> dict: - return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } if index.storage_type == StorageType.HASH: await index.load(sample_data, preprocess=hash_preprocess, id_field="user") diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index df348f83..752d5fe7 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -103,7 +103,10 @@ def index(sample_data, redis_url): # Prepare and load the data def hash_preprocess(item: dict) -> dict: - return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } index.load(sample_data, preprocess=hash_preprocess) diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index e3961e6b..c9ece92f 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -50,7 +50,7 @@ def test_cache_entry_to_dict(): result = entry.to_dict() assert result["entry_id"] == hashify("What is AI?") assert result["metadata"] == json.dumps({"author": "John"}) - assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3]) + assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3], "float32") assert result["category"] == "technology" assert "filters" not in result diff --git a/tests/unit/test_session_schema.py b/tests/unit/test_session_schema.py index b25f6564..3a7e228b 100644 --- a/tests/unit/test_session_schema.py +++ b/tests/unit/test_session_schema.py @@ -103,7 +103,7 @@ def test_chat_message_to_dict(): assert data["content"] == content assert data["session_tag"] == session_tag assert data["timestamp"] == timestamp - assert data["vector_field"] == array_to_buffer(vector_field) + assert data["vector_field"] == array_to_buffer(vector_field, "float32") def test_chat_message_missing_fields(): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index ca535c5a..d4ffaaaf 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from ml_dtypes import bfloat16 from redisvl.redis.utils import ( array_to_buffer, @@ -50,27 +51,22 @@ def test_simple_byte_buffer_to_floats(): """Test conversion of a simple byte buffer into floats""" buffer = np.array([1.0, 2.0, 3.0], dtype=np.float32).tobytes() expected = [1.0, 2.0, 3.0] - assert buffer_to_array(buffer, dtype=np.float32) == expected + assert buffer_to_array(buffer, dtype="float32") == expected -def test_different_data_types(): +def test_converting_different_data_types(): """Test conversion with different data types""" - # Integer test - buffer = np.array([1, 2, 3], dtype=np.int32).tobytes() - expected = [1, 2, 3] - assert buffer_to_array(buffer, dtype=np.int32) == expected - # Float64 test buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64).tobytes() expected = [1.0, 2.0, 3.0] - assert buffer_to_array(buffer, dtype=np.float64) == expected + assert buffer_to_array(buffer, dtype="float64") == expected def test_empty_byte_buffer(): """Test conversion of an empty byte buffer""" buffer = b"" expected = [] - assert buffer_to_array(buffer, dtype=np.float32) == expected + assert buffer_to_array(buffer, dtype="float32") == expected def test_plain_bytes_to_string(): @@ -119,7 +115,7 @@ def test_simple_list_to_bytes_default_dtype(): """Test conversion of a simple list of floats to bytes using the default dtype""" array = [1.0, 2.0, 3.0] expected = np.array(array, dtype=np.float32).tobytes() - assert array_to_buffer(array) == expected + assert array_to_buffer(array, "float32") == expected def test_list_to_bytes_non_default_dtype(): @@ -127,17 +123,17 @@ def test_list_to_bytes_non_default_dtype(): array = [1.0, 2.0, 3.0] dtype = np.float64 expected = np.array(array, dtype=dtype).tobytes() - assert array_to_buffer(array, dtype=dtype) == expected + assert array_to_buffer(array, dtype="float64") == expected def test_empty_list_to_bytes(): """Test conversion of an empty list""" array = [] expected = np.array(array, dtype=np.float32).tobytes() - assert array_to_buffer(array) == expected + assert array_to_buffer(array, dtype="float32") == expected -@pytest.mark.parametrize("dtype", [np.int32, np.float64]) +@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "bfloat16"]) def test_conversion_with_various_dtypes(dtype): """Test conversion of a list of floats to bytes with various dtypes""" array = [1.0, -2.0, 3.5] @@ -148,5 +144,5 @@ def test_conversion_with_various_dtypes(dtype): def test_conversion_with_invalid_floats(): """Test conversion with invalid float values (numpy should handle them)""" array = [float("inf"), float("-inf"), float("nan")] - result = array_to_buffer(array) + result = array_to_buffer(array, "float16") assert len(result) > 0 # Simple check to ensure it returns anything From 17d99318db3ac1eb8a30a7ed8a03c3255b1af5ff Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 18 Sep 2024 17:01:27 -0700 Subject: [PATCH 03/25] wip: removes dtype from IndexInfo and reads it from field attrs --- redisvl/extensions/llmcache/schema.py | 4 ++-- redisvl/extensions/llmcache/semantic.py | 7 ++++--- redisvl/extensions/router/schema.py | 2 +- redisvl/extensions/session_manager/schema.py | 2 +- redisvl/schema/schema.py | 2 -- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index c4df4ee3..d2caf966 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -26,7 +26,7 @@ class CacheEntry(BaseModel): """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - dtype: str = Field(default="float32") + dtype: str = Field(default="float32") ### TODO don't have a default here """The data type for the prompt vector.""" @root_validator(pre=True) @@ -110,7 +110,7 @@ class SemanticCacheIndexSchema(IndexSchema): def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str): return cls( - index={"name": name, "prefix": prefix, "dtype": dtype.upper()}, # type: ignore + index={"name": name, "prefix": prefix}, # type: ignore fields=[ # type: ignore {"name": "prompt", "type": "text"}, {"name": "response", "type": "text"}, diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 7c16f358..a88bfc6c 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -238,7 +238,8 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt, dtype=self.index.schema.index.dtype) + dtype = self.index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr] + return self._vectorizer.embed(prompt, dtype=dtype) def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it @@ -315,7 +316,7 @@ def check( num_results=num_results, return_score=True, filter_expression=filter_expression, - dtype=self.index.schema.index.dtype, + dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) cache_hits: List[Dict[Any, str]] = [] @@ -386,7 +387,7 @@ def store( prompt_vector=vector, metadata=metadata, filters=filters, - dtype=self.index.schema.index.dtype, + dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) # Load cache entry with TTL diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index f171baac..e4720ea6 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -99,7 +99,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str): SemanticRouterIndexSchema: The constructed index schema. """ return cls( - index={"name": name, "prefix": name, "dtype": dtype.upper()}, # type: ignore + index={"name": name, "prefix": name}, # type: ignore fields=[ # type: ignore {"name": "route_name", "type": "tag"}, {"name": "reference", "type": "text"}, diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 80934582..1d8aff9f 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -72,7 +72,7 @@ class SemanticSessionIndexSchema(IndexSchema): def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str): return cls( - index={"name": name, "prefix": prefix, "dtype": dtype.upper()}, # type: ignore + index={"name": name, "prefix": prefix}, # type: ignore fields=[ # type: ignore {"name": "role", "type": "tag"}, {"name": "content", "type": "text"}, diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index 64938651..fa6501d8 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -63,8 +63,6 @@ class IndexInfo(BaseModel): """The separator character used in designing Redis keys.""" storage_type: StorageType = StorageType.HASH """The storage type used in Redis (e.g., 'hash' or 'json').""" - dtype: VectorDataType = VectorDataType.FLOAT32 - """The data type for the vector field in the index.""" class IndexSchema(BaseModel): From 95b1ec09102deff761166737652e766c9a300cb2 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 19 Sep 2024 13:49:11 -0700 Subject: [PATCH 04/25] adds session manager schema checks --- redisvl/extensions/llmcache/schema.py | 2 +- .../session_manager/semantic_session.py | 18 ++++++++++- tests/integration/test_llmcache.py | 27 ++++++++++++++++ tests/integration/test_session_manager.py | 31 +++++++++++++++++-- 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index d2caf966..311fe1ce 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -26,7 +26,7 @@ class CacheEntry(BaseModel): """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - dtype: str = Field(default="float32") ### TODO don't have a default here + dtype: str = Field(default="float32") ### TODO don't have a default here? """The data type for the prompt vector.""" @root_validator(pre=True) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 6b83f764..1e74c946 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -27,6 +27,7 @@ def __init__( redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, + overwrite: bool = False, **kwargs, ): """Initialize session memory with index @@ -80,7 +81,17 @@ def __init__( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - self._index.create(overwrite=False) + # Check for existing session index + if not overwrite and self._index.exists(): + existing_index = SearchIndex.from_existing( + name, redis_client=self._index.client + ) + if existing_index.schema != self._index.schema: + raise ValueError( + f"Existing index {name} schema does not match the user provided schema for the semantic session. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + self._index.create(overwrite=overwrite, drop=False) self._default_session_filter = Tag(self.session_field_name) == self._session_tag @@ -191,6 +202,9 @@ def get_relevant( else self._default_session_filter ) + dtype = self._index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr] + print("queryying dtype ", dtype, "for session ", prompt) + query = RangeQuery( vector=self._vectorizer.embed(prompt), vector_field_name=self.vector_field_name, @@ -199,6 +213,7 @@ def get_relevant( num_results=top_k, return_score=True, filter_expression=session_filter, + dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) messages = self._index.query(query) @@ -325,6 +340,7 @@ def add_messages( content=message[self.content_field_name], session_tag=session_tag, vector_field=content_vector, + dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) if self.tool_field_name in message: diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 2263b745..cde617a6 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -555,3 +555,30 @@ def test_index_updating(redis_url): filter_expression=tag_filter, ) assert len(response) == 1 + + +def test_create_cache_with_different_vector_types(): + bfloat_cache = SemanticCache(name="bfloat_cache", dtype="bfloat16") + bfloat_cache.store("bfloat16 prompt", "bfloat16 response") + + float16_cache = SemanticCache(name="float16_cache", dtype="float16") + float16_cache.store("float16 prompt", "float16 response") + + float32_cache = SemanticCache(name="float32_cache", dtype="float32") + float32_cache.store("float32 prompt", "float32 response") + + float64_cache = SemanticCache(name="float64_cache", dtype="float64") + float64_cache.store("float64 prompt", "float64 response") + + for cache in [bfloat_cache, float16_cache, float32_cache, float64_cache]: + cache.set_threshold(0.6) + assert len(cache.check("float prompt", num_results=5)) == 1 + + +def test_bad_dtype_connecting_to_existing_cache(): + cache1 = SemanticCache(name="float64_cache", dtype="float64") + + same_type = SemanticCache(name="float64_cache", dtype="float64") + + with pytest.raises(ValueError): + bad_type = SemanticCache(name="float64_cache", dtype="float16") diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 20c2955d..ba482402 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -16,7 +16,7 @@ def standard_session(app_name, client): @pytest.fixture def semantic_session(app_name, client): - session = SemanticSessionManager(app_name, redis_client=client) + session = SemanticSessionManager(app_name, redis_client=client, overwrite=True) yield session session.clear() session.delete() @@ -284,7 +284,7 @@ def test_standard_clear(standard_session): # test semantic session manager def test_semantic_specify_client(client): session = SemanticSessionManager( - name="test_app", session_tag="abc", redis_client=client + name="test_app", session_tag="abc", redis_client=client, overwrite=True ) assert isinstance(session._index.client, type(client)) @@ -536,3 +536,30 @@ def test_semantic_drop(semantic_session): {"role": "llm", "content": "third response"}, {"role": "user", "content": "fourth prompt"}, ] + + +def test_different_vector_dtypes(): + bfloat_sess = SemanticSessionManager(name="bfloat_session", dtype="bfloat16") + bfloat_sess.add_message({"role": "user", "content": "bfloat message"}) + + float16_sess = SemanticSessionManager(name="float16_session", dtype="float16") + float16_sess.add_message({"role": "user", "content": "float16 message"}) + + float32_sess = SemanticSessionManager(name="float32_session", dtype="float32") + float32_sess.add_message({"role": "user", "content": "float32 message"}) + + float64_sess = SemanticSessionManager(name="float64_session", dtype="float64") + float64_sess.add_message({"role": "user", "content": "float64 message"}) + + for sess in [bfloat_sess, float16_sess, float32_sess, float64_sess]: + sess.set_distance_threshold(0.7) + assert len(sess.get_relevant("float message")) == 1 + + +def test_bad_dtype_connecting_to_exiting_session(): + session = SemanticSessionManager(name="float64 session", dtype="float64") + + same_type = SemanticSessionManager(name="float64 session", dtype="float64") + + with pytest.raises(ValueError): + bad_type = SemanticSessionManager(name="float64 session", dtype="float16") From a65eca7008911ad6451306441f0f6df66d3ed675 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 19 Sep 2024 16:16:48 -0700 Subject: [PATCH 05/25] cleans up session manager vector typing --- redisvl/extensions/session_manager/semantic_session.py | 4 ---- redisvl/utils/vectorize/base.py | 7 +++++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 1e74c946..3f956b61 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -202,9 +202,6 @@ def get_relevant( else self._default_session_filter ) - dtype = self._index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr] - print("queryying dtype ", dtype, "for session ", prompt) - query = RangeQuery( vector=self._vectorizer.embed(prompt), vector_field_name=self.vector_field_name, @@ -329,7 +326,6 @@ def add_messages( for message in messages: content_vector = self._vectorizer.embed(message[self.content_field_name]) - validate_vector_dims( len(content_vector), self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index cb6703a5..df87b190 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -82,7 +82,10 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None): yield seq[pos : pos + size] def _process_embedding(self, embedding: List[float], as_buffer: bool, **kwargs): - dtype = kwargs.get("dtype", "float32") if as_buffer: - return array_to_buffer(embedding, dtype) + if "dtype" not in kwargs: + raise RuntimeError( + "dtype is required if converting from float to byte string." + ) + return array_to_buffer(embedding, kwargs["dtype"]) return embedding From 947ccaf0142e136bdc3f3d063dd424bfe9c10549 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 19 Sep 2024 16:42:45 -0700 Subject: [PATCH 06/25] updates router to specify vector dtype --- redisvl/extensions/router/schema.py | 2 +- redisvl/extensions/router/semantic.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index e4720ea6..69c14f10 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -3,7 +3,7 @@ from pydantic.v1 import BaseModel, Field, validator -from redisvl.schema import IndexInfo, IndexSchema +from redisvl.schema import IndexSchema class Route(BaseModel): diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index ba310709..aa87db72 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -40,6 +40,7 @@ class SemanticRouter(BaseModel): """The vectorizer used to embed route references.""" routing_config: RoutingConfig = Field(default_factory=RoutingConfig) """Configuration for routing behavior.""" + vector_field_name: str = "vector" _index: SearchIndex = PrivateAttr() @@ -108,8 +109,18 @@ def _initialize_index( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) + # Check for existing session index existed = self._index.exists() - self._index.create(overwrite=overwrite) + if not overwrite and existed: + existing_index = SearchIndex.from_existing( + self.name, redis_client=self._index.client + ) + if existing_index.schema != self._index.schema: + raise ValueError( + f"Existing index {self.name} schema does not match the user provided schema for the semantic router. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + self._index.create(overwrite=overwrite, drop=False) if not existed or overwrite: # write the routes to Redis @@ -158,7 +169,9 @@ def _add_routes(self, routes: List[Route]): for route in routes: # embed route references as a single batch reference_vectors = self.vectorizer.embed_many( - [reference for reference in route.references], as_buffer=True + [reference for reference in route.references], + as_buffer=True, + dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) # set route references for i, reference in enumerate(route.references): @@ -235,6 +248,7 @@ def _classify_route( vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name"], + dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( @@ -287,6 +301,7 @@ def _classify_multi_route( vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name"], + dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k From 196ca0ad570bba3c628813f74da8724a14ec7291 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 19 Sep 2024 16:58:30 -0700 Subject: [PATCH 07/25] makes dtype required for CacheEntry --- redisvl/extensions/llmcache/schema.py | 2 +- tests/integration/test_semantic_router.py | 50 +++++++++++++++++++++++ tests/unit/test_llmcache_schema.py | 5 +++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 311fe1ce..e90c8bf6 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -26,7 +26,7 @@ class CacheEntry(BaseModel): """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - dtype: str = Field(default="float32") ### TODO don't have a default here? + dtype: str """The data type for the prompt vector.""" @root_validator(pre=True) diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index b2a7c716..12487437 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -237,3 +237,53 @@ def test_bad_connection_info(routes): redis_url="redis://localhost:6389", # bad connection url overwrite=False, ) + + +def test_different_vector_dtypes(routes): + bfloat_router = SemanticRouter( + name="bfloat_router", + routes=routes, + dtype="bfloat16", + ) + + float16_router = SemanticRouter( + name="float16_router", + routes=routes, + dtype="float16", + ) + + float32_router = SemanticRouter( + name="float32_router", + routes=routes, + dtype="float32", + ) + + float64_router = SemanticRouter( + name="float64_router", + routes=routes, + dtype="float64", + ) + + for router in [bfloat_router, float16_router, float32_router, float64_router]: + assert len(router.route_many("hello", max_k=5)) == 1 + + +def test_bad_dtype_connecting_to_exiting_router(routes): + router = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float64", + ) + + same_type = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float64", + ) + + with pytest.raises(ValueError): + bad_type = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float16", + ) diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index c9ece92f..7c1755d2 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -12,6 +12,7 @@ def test_valid_cache_entry_creation(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], + dtype="float16", ) assert entry.entry_id == hashify("What is AI?") assert entry.prompt == "What is AI?" @@ -25,6 +26,7 @@ def test_cache_entry_with_given_entry_id(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], + dtype="float16", ) assert entry.entry_id == "custom_id" @@ -36,6 +38,7 @@ def test_cache_entry_with_invalid_metadata(): response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], metadata="invalid_metadata", + dtype="float64", ) @@ -46,6 +49,7 @@ def test_cache_entry_to_dict(): prompt_vector=[0.1, 0.2, 0.3], metadata={"author": "John"}, filters={"category": "technology"}, + dtype="float32", ) result = entry.to_dict() assert result["entry_id"] == hashify("What is AI?") @@ -108,6 +112,7 @@ def test_cache_entry_with_empty_optional_fields(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], + dtype="bfloat16", ) result = entry.to_dict() assert "metadata" not in result From 2a2d5b440bca372c2d0b20ec7999688233983435 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 20 Sep 2024 10:23:36 -0700 Subject: [PATCH 08/25] formatting --- redisvl/query/query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 1b8541bc..856a2572 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -259,7 +259,7 @@ def params(self) -> Dict[str, Any]: if isinstance(self._vector, bytes): vector = self._vector else: - vector= array_to_buffer(self._vector, dtype=self._dtype) + vector = array_to_buffer(self._vector, dtype=self._dtype) return {self.VECTOR_PARAM: vector} From 9b0d86a508ee813518ced748c67339f06520c3da Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 20 Sep 2024 11:42:15 -0700 Subject: [PATCH 09/25] addressing PR comments --- pyproject.toml | 2 +- redisvl/index/index.py | 2 +- redisvl/index/storage.py | 11 ----------- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index be1dbf3b..027652d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,12 +26,12 @@ redis = ">=5.0.0" pydantic = { version = ">=2,<3" } tenacity = ">=8.2.2" tabulate = { version = ">=0.9.0,<1" } +ml-dtypes = "^0.4.0" openai = { version = ">=1.13.0", optional = true } sentence-transformers = { version = ">=2.2.2", optional = true } google-cloud-aiplatform = { version = ">=1.26", optional = true } cohere = { version = ">=4.44", optional = true } mistralai = { version = ">=0.2.0", optional = true } -ml-dtypes = "^0.4.0" [tool.poetry.extras] openai = ["openai"] diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 7d71ed7c..b4adb6b3 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -354,7 +354,7 @@ def from_existing( # Validate modules installed_modules = RedisConnectionFactory.get_modules(redis_client) - validate_modules(installed_modules, [{"name": "search", "ver": 21005}]) + validate_modules(installed_modules, [{"name": "search", "ver": 20810}]) # Fetch index info and convert to schema index_info = cls._info(name, redis_client) diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index 9b9aa36c..ef2c50d9 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -407,17 +407,6 @@ def _validate(self, obj: Dict[str, Any]): if not isinstance(obj, dict): raise TypeError("Object must be a dictionary.") - try: - byte_string = obj["prompt_vector"] - vector = frombuffer(byte_string, dtype=obj["dtype"].lower()) - except ValueError: - raise ValueError( - f"Could not convert byte string to vector of type {obj['dtype']}.", - "buffer size must be a multiple of element size", - ) - except KeyError: - pass # regular hash entry with no prompt_vector or dtype needed - @staticmethod def _set(client: Redis, key: str, obj: Dict[str, Any]): """Synchronously set a hash value in Redis for the given key. From 0b98281d301d155c816edf8ea7d2f61a747b2f9b Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Sat, 21 Sep 2024 18:29:56 -0700 Subject: [PATCH 10/25] changes dtype arg to string in notebook --- docs/user_guide/hash_vs_json_05.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/hash_vs_json_05.ipynb b/docs/user_guide/hash_vs_json_05.ipynb index a046a963..9cb0092b 100644 --- a/docs/user_guide/hash_vs_json_05.ipynb +++ b/docs/user_guide/hash_vs_json_05.ipynb @@ -429,7 +429,7 @@ "json_data = data.copy()\n", "\n", "for d in json_data:\n", - " d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype=np.float32)" + " d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype='float32')" ] }, { From d289a44cb619714583cca67904f3786296c856fb Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Sat, 21 Sep 2024 18:40:11 -0700 Subject: [PATCH 11/25] specifies vector dtype when creating byte vectors in notebooks --- docs/examples/openai_qna.ipynb | 4 ++-- docs/user_guide/vectorizers_04.ipynb | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/examples/openai_qna.ipynb b/docs/examples/openai_qna.ipynb index 614ed3e3..a1e59490 100644 --- a/docs/examples/openai_qna.ipynb +++ b/docs/examples/openai_qna.ipynb @@ -579,7 +579,7 @@ "api_key = os.getenv(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n", "oaip = OpenAITextVectorizer(EMBEDDINGS_MODEL, api_config={\"api_key\": api_key})\n", "\n", - "chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True)\n", + "chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True, dtype=\"float32\")\n", "chunked_data" ] }, @@ -1073,7 +1073,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.12.2" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/vectorizers_04.ipynb b/docs/user_guide/vectorizers_04.ipynb index f9bc9b82..e3ca2772 100644 --- a/docs/user_guide/vectorizers_04.ipynb +++ b/docs/user_guide/vectorizers_04.ipynb @@ -356,7 +356,7 @@ "outputs": [], "source": [ "# You can also create many embeddings at once\n", - "embeddings = hf.embed_many(sentences, as_buffer=True)\n" + "embeddings = hf.embed_many(sentences, as_buffer=True, dtype=\"float32\")\n" ] }, { From 45718c714948c4ef7dc6befca6e30849a3d2da8b Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Sat, 21 Sep 2024 18:57:06 -0700 Subject: [PATCH 12/25] adds kargs to custom embedding function to allow for accepting dtype --- docs/user_guide/vectorizers_04.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/vectorizers_04.ipynb b/docs/user_guide/vectorizers_04.ipynb index e3ca2772..90b05892 100644 --- a/docs/user_guide/vectorizers_04.ipynb +++ b/docs/user_guide/vectorizers_04.ipynb @@ -569,7 +569,7 @@ "source": [ "from redisvl.utils.vectorize import CustomTextVectorizer\n", "\n", - "def generate_embeddings(text_input):\n", + "def generate_embeddings(text_input, **kwargs):\n", " return [0.101] * 768\n", "\n", "custom_vectorizer = CustomTextVectorizer(generate_embeddings)\n", From 0c215e4c12c757d6bfeea203990489656470524e Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Mon, 23 Sep 2024 10:57:25 -0700 Subject: [PATCH 13/25] updates docstring in semantic session to include overwrite argument --- redisvl/extensions/session_manager/semantic_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 2f6b8cda..1d2c553b 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -52,7 +52,8 @@ def __init__( redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. connection_kwargs (Dict[str, Any]): The connection arguments for the redis client. Defaults to empty {}. - dtype (str): The data type for the prompt vector. Defaults to "float32". + overwrite (bool): Whether or not to force overwrite the schema for + the semantic session index. Defaults to false. The proposed schema will support a single vector embedding constructed from either the prompt or response in a single string. From b348a222ed33efcd5e417bd9b618b80f77d12c0e Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Mon, 23 Sep 2024 12:16:01 -0700 Subject: [PATCH 14/25] removes VectorDataTyps dict to use exisiting VectorDataType Enum --- redisvl/redis/utils.py | 23 +++++++---------------- redisvl/schema/schema.py | 2 +- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 3186fcfd..625a8833 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -4,12 +4,7 @@ import numpy as np from ml_dtypes import bfloat16 -VectorDataTypes = { - "BFLOAT16": bfloat16, - "FLOAT16": np.float16, - "FLOAT32": np.float32, - "FLOAT64": np.float64, -} +from redisvl.schema.fields import VectorDataType def make_dict(values: List[Any]) -> Dict[Any, Any]: @@ -40,24 +35,20 @@ def convert_bytes(data: Any) -> Any: def array_to_buffer(array: List[float], dtype: str) -> bytes: """Convert a list of floats into a numpy byte string.""" - try: - dtype = VectorDataTypes[dtype.upper()] - except KeyError: + if dtype.upper() not in VectorDataType: raise ValueError( - f"Invalid data type: {dtype}. Supported types are: {VectorDataTypes.keys()}" + f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) - return np.array(array).astype(dtype).tobytes() + return np.array(array).astype(dtype.lower()).tobytes() def buffer_to_array(buffer: bytes, dtype: str) -> List[float]: """Convert bytes into into a list of floats.""" - try: - dtype = VectorDataTypes[dtype.upper()] - except KeyError: + if dtype.upper() not in VectorDataType: raise ValueError( - f"Invalid data type: {dtype}. Supported types are: {VectorDataTypes.keys()}" + f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) - return np.frombuffer(buffer, dtype=dtype).tolist() + return np.frombuffer(buffer, dtype=dtype.lower()).tolist() def hashify(content: str) -> str: diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index fa6501d8..7f3db845 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -7,7 +7,7 @@ from pydantic.v1 import BaseModel, Field, root_validator from redis.commands.search.field import Field as RedisField -from redisvl.schema.fields import BaseField, FieldFactory, VectorDataType +from redisvl.schema.fields import BaseField, FieldFactory from redisvl.utils.log import get_logger from redisvl.utils.utils import model_to_dict From fc9712e7014fe9853a3f1627949275370aa2d954 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Mon, 23 Sep 2024 14:35:29 -0700 Subject: [PATCH 15/25] changes enum membership check for python 3.9 compatibility --- redisvl/extensions/router/semantic.py | 2 +- redisvl/index/storage.py | 3 +-- redisvl/redis/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index aa87db72..7e3fd9b1 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -109,7 +109,7 @@ def _initialize_index( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - # Check for existing session index + # Check for existing router index existed = self._index.exists() if not overwrite and existed: existing_index = SearchIndex.from_existing( diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index ef2c50d9..209ea6f4 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -395,14 +395,13 @@ class HashStorage(BaseStorage): def _validate(self, obj: Dict[str, Any]): """Validate that the given object is a dictionary suitable for storage - as a Redis hash, and the vector byte string is of correct datatype. + as a Redis hash. Args: obj (Dict[str, Any]): The object to validate. Raises: TypeError: If the object is not a dictionary. - ValueError: if the vector byte string is not of correct datatype. """ if not isinstance(obj, dict): raise TypeError("Object must be a dictionary.") diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 625a8833..2f89ff06 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -35,7 +35,7 @@ def convert_bytes(data: Any) -> Any: def array_to_buffer(array: List[float], dtype: str) -> bytes: """Convert a list of floats into a numpy byte string.""" - if dtype.upper() not in VectorDataType: + if dtype.upper() not in {v.value for v in VectorDataType}: raise ValueError( f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) @@ -44,7 +44,7 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes: def buffer_to_array(buffer: bytes, dtype: str) -> List[float]: """Convert bytes into into a list of floats.""" - if dtype.upper() not in VectorDataType: + if dtype.upper() not in {v.value for v in VectorDataType}: raise ValueError( f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) From a8bf2df734b18251e40631311564c66d125fbaae Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 24 Sep 2024 09:51:52 -0700 Subject: [PATCH 16/25] changes dtype membership check --- redisvl/redis/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 2f89ff06..880f55c3 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -35,7 +35,9 @@ def convert_bytes(data: Any) -> Any: def array_to_buffer(array: List[float], dtype: str) -> bytes: """Convert a list of floats into a numpy byte string.""" - if dtype.upper() not in {v.value for v in VectorDataType}: + try: + VectorDataType(dtype.upper()) + except ValueError: raise ValueError( f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) @@ -44,7 +46,9 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes: def buffer_to_array(buffer: bytes, dtype: str) -> List[float]: """Convert bytes into into a list of floats.""" - if dtype.upper() not in {v.value for v in VectorDataType}: + try: + VectorDataType(dtype.upper()) + except ValueError: raise ValueError( f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" ) From 89400bd29c68ef4958cc32fa521c71bcdd995f48 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 24 Sep 2024 15:19:41 -0700 Subject: [PATCH 17/25] updates GHA test workflow to skip vector dypes on unsupported redis-stack version --- .github/workflows/run_tests.yml | 7 ++++++- tests/integration/test_llmcache.py | 15 +++++++++++++-- tests/integration/test_semantic_router.py | 13 ++++++++++++- tests/integration/test_session_manager.py | 14 +++++++++++++- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 6b6b47d6..f37cd4ca 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -70,7 +70,12 @@ jobs: poetry run test-cov - name: Run tests - if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest' + if: matrix.redis-stack-version == '6.2.6-v9' + run: | + SKIP_DTYPES=True SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov + + - name: Run tests + if: matrix.redis-stack-version == 'edge' || (matrix.connection == 'hiredis' && matrix.redis-stack-version == 'latest') run: | SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index ea6f7f7e..e0253588 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,4 +1,5 @@ import asyncio +import os from collections import namedtuple from time import sleep, time @@ -17,6 +18,13 @@ def vectorizer(): return HFTextVectorizer("sentence-transformers/all-mpnet-base-v2") +@pytest.fixture +def skip_dtypes() -> bool: + # os.getenv returns a string + v = os.getenv("SKIP_DTYPES", "False").lower() == "true" + return v + + @pytest.fixture def cache(vectorizer, redis_url): cache_instance = SemanticCache( @@ -802,7 +810,10 @@ def test_index_updating(redis_url): assert len(response) == 1 -def test_create_cache_with_different_vector_types(): +def test_create_cache_with_different_vector_types(skip_dtypes): + if skip_dtypes: + pytest.skip("Skipping dtype checking...") + bfloat_cache = SemanticCache(name="bfloat_cache", dtype="bfloat16") bfloat_cache.store("bfloat16 prompt", "bfloat16 response") @@ -821,7 +832,7 @@ def test_create_cache_with_different_vector_types(): def test_bad_dtype_connecting_to_existing_cache(): - cache1 = SemanticCache(name="float64_cache", dtype="float64") + cache = SemanticCache(name="float64_cache", dtype="float64") same_type = SemanticCache(name="float64_cache", dtype="float64") diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 12487437..e7849ba4 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,3 +1,4 @@ +import os import pathlib import pytest @@ -8,6 +9,13 @@ from redisvl.redis.connection import compare_versions +@pytest.fixture +def skip_dtypes() -> bool: + # os.getenv returns a string + v = os.getenv("SKIP_DTYPES", "False").lower() == "true" + return v + + def get_base_path(): return pathlib.Path(__file__).parent.resolve() @@ -239,7 +247,10 @@ def test_bad_connection_info(routes): ) -def test_different_vector_dtypes(routes): +def test_different_vector_dtypes(routes, skip_dtypes): + if skip_dtypes: + pytest.skip("Skipping dtype checking...") + bfloat_router = SemanticRouter( name="bfloat_router", routes=routes, diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index ba482402..20f1211f 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -1,3 +1,5 @@ +import os + import pytest from redis.exceptions import ConnectionError @@ -7,6 +9,13 @@ ) +@pytest.fixture +def skip_dtypes() -> bool: + # os.getenv returns a string + v = os.getenv("SKIP_DTYPES", "False").lower() == "true" + return v + + @pytest.fixture def standard_session(app_name, client): session = StandardSessionManager(app_name, redis_client=client) @@ -538,7 +547,10 @@ def test_semantic_drop(semantic_session): ] -def test_different_vector_dtypes(): +def test_different_vector_dtypes(skip_dtypes): + if skip_dtypes: + pytest.skip("Skipping dtype checking...") + bfloat_sess = SemanticSessionManager(name="bfloat_session", dtype="bfloat16") bfloat_sess.add_message({"role": "user", "content": "bfloat message"}) From 41f569375df24f052dfc8426eabd3bb4aa233edf Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 24 Sep 2024 15:38:58 -0700 Subject: [PATCH 18/25] skips more dtype tests on old redis stack version --- tests/integration/test_llmcache.py | 5 ++++- tests/integration/test_semantic_router.py | 5 ++++- tests/integration/test_session_manager.py | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index e0253588..19aed2d7 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -831,7 +831,10 @@ def test_create_cache_with_different_vector_types(skip_dtypes): assert len(cache.check("float prompt", num_results=5)) == 1 -def test_bad_dtype_connecting_to_existing_cache(): +def test_bad_dtype_connecting_to_existing_cache(skip_dtypes): + if skip_dtypes: + pytest.skip("Skipping dtype checking...") + cache = SemanticCache(name="float64_cache", dtype="float64") same_type = SemanticCache(name="float64_cache", dtype="float64") diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index e7849ba4..71422264 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -279,7 +279,10 @@ def test_different_vector_dtypes(routes, skip_dtypes): assert len(router.route_many("hello", max_k=5)) == 1 -def test_bad_dtype_connecting_to_exiting_router(routes): +def test_bad_dtype_connecting_to_exiting_router(routes, skip_dtypes): + if skip_dtypes: + pytest.skip("Skipping dtype checking...") + router = SemanticRouter( name="float64 router", routes=routes, diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 20f1211f..3d112ff5 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -568,7 +568,10 @@ def test_different_vector_dtypes(skip_dtypes): assert len(sess.get_relevant("float message")) == 1 -def test_bad_dtype_connecting_to_exiting_session(): +def test_bad_dtype_connecting_to_exiting_session(skip_dtypes): + if skip_dtypes: + pytest.skip("Skipping dtype checking...") + session = SemanticSessionManager(name="float64 session", dtype="float64") same_type = SemanticSessionManager(name="float64 session", dtype="float64") From 6898f1de76bf93590e8503ec4d4e3a6aa6fd97a2 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 25 Sep 2024 12:10:11 -0700 Subject: [PATCH 19/25] removes dtype from class definitions, and uses constants instead --- redisvl/extensions/llmcache/schema.py | 6 ++--- redisvl/extensions/llmcache/semantic.py | 24 +++++++++---------- redisvl/extensions/router/semantic.py | 10 ++++---- .../session_manager/semantic_session.py | 12 ++++++---- redisvl/index/storage.py | 3 +-- tests/unit/test_llmcache_schema.py | 9 ++----- 6 files changed, 30 insertions(+), 34 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index e90c8bf6..77903f7c 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -26,8 +26,6 @@ class CacheEntry(BaseModel): """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - dtype: str - """The data type for the prompt vector.""" @root_validator(pre=True) @classmethod @@ -43,9 +41,9 @@ def non_empty_metadata(cls, v): raise TypeError("Metadata must be a dictionary.") return v - def to_dict(self) -> Dict: + def to_dict(self, dtype: str) -> Dict: data = self.dict(exclude_none=True) - data["prompt_vector"] = array_to_buffer(self.prompt_vector, self.dtype) + data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype) if self.metadata is not None: data["metadata"] = serialize(self.metadata) if self.filters is not None: diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index d57861af..3c165a09 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -15,6 +15,8 @@ from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +VECTOR_FIELD_NAME = "prompt_vector" ### + class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" @@ -23,7 +25,7 @@ class SemanticCache(BaseLLMCache): entry_id_field_name: str = "entry_id" prompt_field_name: str = "prompt" response_field_name: str = "response" - vector_field_name: str = "prompt_vector" + ###vector_field_name: str = "prompt_vector" inserted_at_field_name: str = "inserted_at" updated_at_field_name: str = "updated_at" metadata_field_name: str = "metadata" @@ -136,9 +138,10 @@ def __init__( validate_vector_dims( vectorizer.dims, - self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore ) self._vectorizer = vectorizer + self._dtype = self.index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr] def _modify_schema( self, @@ -290,8 +293,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - dtype = self.index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr] - return self._vectorizer.embed(prompt, dtype=dtype) + return self._vectorizer.embed(prompt, dtype=self._dtype) async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -304,7 +306,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" - schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore + schema_vector_dims = self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims # type: ignore validate_vector_dims(len(vector), schema_vector_dims) def check( @@ -367,13 +369,13 @@ def check( query = RangeQuery( vector=vector, - vector_field_name=self.vector_field_name, + vector_field_name=VECTOR_FIELD_NAME, return_fields=self.return_fields, distance_threshold=distance_threshold, num_results=num_results, return_score=True, filter_expression=filter_expression, - dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._dtype, ) # Search the cache! @@ -449,7 +451,7 @@ async def acheck( query = RangeQuery( vector=vector, - vector_field_name=self.vector_field_name, + vector_field_name=VECTOR_FIELD_NAME, return_fields=self.return_fields, distance_threshold=distance_threshold, num_results=num_results, @@ -539,13 +541,12 @@ def store( prompt_vector=vector, metadata=metadata, filters=filters, - dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) # Load cache entry with TTL ttl = ttl or self._ttl keys = self._index.load( - data=[cache_entry.to_dict()], + data=[cache_entry.to_dict(self._dtype)], ttl=ttl, id_field=self.entry_id_field_name, ) @@ -604,13 +605,12 @@ async def astore( prompt_vector=vector, metadata=metadata, filters=filters, - dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) # Load cache entry with TTL ttl = ttl or self._ttl keys = await aindex.load( - data=[cache_entry.to_dict()], + data=[cache_entry.to_dict(self._dtype)], ttl=ttl, id_field=self.entry_id_field_name, ) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 7e3fd9b1..66d38507 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -28,6 +28,8 @@ logger = get_logger(__name__) +VECTOR_FIELD_NAME = "vector" ### + class SemanticRouter(BaseModel): """Semantic Router for managing and querying route vectors.""" @@ -40,7 +42,7 @@ class SemanticRouter(BaseModel): """The vectorizer used to embed route references.""" routing_config: RoutingConfig = Field(default_factory=RoutingConfig) """Configuration for routing behavior.""" - vector_field_name: str = "vector" + ### vector_field_name: str = "vector" _index: SearchIndex = PrivateAttr() @@ -171,7 +173,7 @@ def _add_routes(self, routes: List[Route]): reference_vectors = self.vectorizer.embed_many( [reference for reference in route.references], as_buffer=True, - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) # set route references for i, reference in enumerate(route.references): @@ -248,7 +250,7 @@ def _classify_route( vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name"], - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( @@ -301,7 +303,7 @@ def _classify_multi_route( vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name"], - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 1d2c553b..9835dc11 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -13,9 +13,11 @@ from redisvl.utils.utils import validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +VECTOR_FIELD_NAME = "vector_field" ### + class SemanticSessionManager(BaseSessionManager): - vector_field_name: str = "vector_field" + ###vector_field_name: str = "vector_field" def __init__( self, @@ -201,13 +203,13 @@ def get_relevant( query = RangeQuery( vector=self._vectorizer.embed(prompt), - vector_field_name=self.vector_field_name, + vector_field_name=VECTOR_FIELD_NAME, return_fields=return_fields, distance_threshold=distance_threshold, num_results=top_k, return_score=True, filter_expression=session_filter, - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) messages = self._index.query(query) @@ -321,7 +323,7 @@ def add_messages( content_vector = self._vectorizer.embed(message[self.content_field_name]) validate_vector_dims( len(content_vector), - self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore ) chat_message = ChatMessage( @@ -329,7 +331,7 @@ def add_messages( content=message[self.content_field_name], session_tag=session_tag, vector_field=content_vector, - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) if self.tool_field_name in message: diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index 209ea6f4..12ef2052 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -2,7 +2,6 @@ import uuid from typing import Any, Callable, Dict, Iterable, List, Optional -from numpy import frombuffer from pydantic.v1 import BaseModel from redis import Redis from redis.asyncio import Redis as AsyncRedis @@ -394,7 +393,7 @@ class HashStorage(BaseStorage): """Hash data type for the index""" def _validate(self, obj: Dict[str, Any]): - """Validate that the given object is a dictionary suitable for storage + """Validate that the given object is a dictionary, suitable for storage as a Redis hash. Args: diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index 7c1755d2..f210efce 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -12,7 +12,6 @@ def test_valid_cache_entry_creation(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - dtype="float16", ) assert entry.entry_id == hashify("What is AI?") assert entry.prompt == "What is AI?" @@ -26,7 +25,6 @@ def test_cache_entry_with_given_entry_id(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - dtype="float16", ) assert entry.entry_id == "custom_id" @@ -38,7 +36,6 @@ def test_cache_entry_with_invalid_metadata(): response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], metadata="invalid_metadata", - dtype="float64", ) @@ -49,9 +46,8 @@ def test_cache_entry_to_dict(): prompt_vector=[0.1, 0.2, 0.3], metadata={"author": "John"}, filters={"category": "technology"}, - dtype="float32", ) - result = entry.to_dict() + result = entry.to_dict(dtype="float32") assert result["entry_id"] == hashify("What is AI?") assert result["metadata"] == json.dumps({"author": "John"}) assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3], "float32") @@ -112,9 +108,8 @@ def test_cache_entry_with_empty_optional_fields(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - dtype="bfloat16", ) - result = entry.to_dict() + result = entry.to_dict(dtype="float32") assert "metadata" not in result assert "filters" not in result From 18ccf16482a54ea7e276138bda226bae876d4799 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 25 Sep 2024 15:24:10 -0700 Subject: [PATCH 20/25] lowers required search module to 2.6.20 --- redisvl/extensions/llmcache/semantic.py | 3 +-- redisvl/extensions/router/semantic.py | 3 +-- redisvl/extensions/session_manager/semantic_session.py | 3 +-- redisvl/index/index.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 3c165a09..8ba7838e 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -15,7 +15,7 @@ from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -VECTOR_FIELD_NAME = "prompt_vector" ### +VECTOR_FIELD_NAME = "prompt_vector" class SemanticCache(BaseLLMCache): @@ -25,7 +25,6 @@ class SemanticCache(BaseLLMCache): entry_id_field_name: str = "entry_id" prompt_field_name: str = "prompt" response_field_name: str = "response" - ###vector_field_name: str = "prompt_vector" inserted_at_field_name: str = "inserted_at" updated_at_field_name: str = "updated_at" metadata_field_name: str = "metadata" diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 66d38507..e4b98469 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -28,7 +28,7 @@ logger = get_logger(__name__) -VECTOR_FIELD_NAME = "vector" ### +VECTOR_FIELD_NAME = "vector" class SemanticRouter(BaseModel): @@ -42,7 +42,6 @@ class SemanticRouter(BaseModel): """The vectorizer used to embed route references.""" routing_config: RoutingConfig = Field(default_factory=RoutingConfig) """Configuration for routing behavior.""" - ### vector_field_name: str = "vector" _index: SearchIndex = PrivateAttr() diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 9835dc11..51c1cbf5 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -13,11 +13,10 @@ from redisvl.utils.utils import validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -VECTOR_FIELD_NAME = "vector_field" ### +VECTOR_FIELD_NAME = "vector_field" class SemanticSessionManager(BaseSessionManager): - ###vector_field_name: str = "vector_field" def __init__( self, diff --git a/redisvl/index/index.py b/redisvl/index/index.py index b4adb6b3..c6a371d2 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -354,7 +354,7 @@ def from_existing( # Validate modules installed_modules = RedisConnectionFactory.get_modules(redis_client) - validate_modules(installed_modules, [{"name": "search", "ver": 20810}]) + validate_modules(installed_modules, [{"name": "search", "ver": 20620}]) # Fetch index info and convert to schema index_info = cls._info(name, redis_client) From 816eeeb2eb8c8eb4dadec409b405126a793b34c3 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 25 Sep 2024 15:31:03 -0700 Subject: [PATCH 21/25] lowers required search module to 2.6.12 --- redisvl/index/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index c6a371d2..cd23c1cb 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -354,7 +354,7 @@ def from_existing( # Validate modules installed_modules = RedisConnectionFactory.get_modules(redis_client) - validate_modules(installed_modules, [{"name": "search", "ver": 20620}]) + validate_modules(installed_modules, [{"name": "search", "ver": 20612}]) # Fetch index info and convert to schema index_info = cls._info(name, redis_client) From d4ad2d251b0ccf641c6c21db0777dcb788198e6b Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 25 Sep 2024 16:48:16 -0700 Subject: [PATCH 22/25] super hacky fix to version compatibility issue --- redisvl/redis/connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 21095cde..7a20be78 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -90,7 +90,9 @@ def convert_index_info_to_schema(index_info: Dict[str, Any]) -> Dict[str, Any]: def parse_vector_attrs(attrs): vector_attrs = {attrs[i].lower(): attrs[i + 1] for i in range(6, len(attrs), 2)} - vector_attrs["dims"] = int(vector_attrs.pop("dim")) + vector_attrs["dims"] = int( + vector_attrs.pop("dim" if "dim" in vector_attrs else "dims") + ) vector_attrs["distance_metric"] = vector_attrs.pop("distance_metric").lower() vector_attrs["algorithm"] = vector_attrs.pop("algorithm").lower() vector_attrs["datatype"] = vector_attrs.pop("data_type").lower() From f0efe5c0ac542b0458338e7c6e657a2f19eeb51c Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 25 Sep 2024 17:10:30 -0700 Subject: [PATCH 23/25] reverts hacky fix and module version checked --- redisvl/index/index.py | 2 +- redisvl/redis/connection.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index cd23c1cb..b4adb6b3 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -354,7 +354,7 @@ def from_existing( # Validate modules installed_modules = RedisConnectionFactory.get_modules(redis_client) - validate_modules(installed_modules, [{"name": "search", "ver": 20612}]) + validate_modules(installed_modules, [{"name": "search", "ver": 20810}]) # Fetch index info and convert to schema index_info = cls._info(name, redis_client) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 7a20be78..21095cde 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -90,9 +90,7 @@ def convert_index_info_to_schema(index_info: Dict[str, Any]) -> Dict[str, Any]: def parse_vector_attrs(attrs): vector_attrs = {attrs[i].lower(): attrs[i + 1] for i in range(6, len(attrs), 2)} - vector_attrs["dims"] = int( - vector_attrs.pop("dim" if "dim" in vector_attrs else "dims") - ) + vector_attrs["dims"] = int(vector_attrs.pop("dim")) vector_attrs["distance_metric"] = vector_attrs.pop("distance_metric").lower() vector_attrs["algorithm"] = vector_attrs.pop("algorithm").lower() vector_attrs["datatype"] = vector_attrs.pop("data_type").lower() From 462af6c61f6a189cd13676c3eefff92e2a2650eb Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 1 Oct 2024 17:36:36 -0700 Subject: [PATCH 24/25] removes local vector field name constant --- redisvl/extensions/session_manager/semantic_session.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index f9256f59..ce6c8f0d 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -22,8 +22,6 @@ from redisvl.utils.utils import validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -VECTOR_FIELD_NAME = "vector_field" - class SemanticSessionManager(BaseSessionManager): @@ -217,7 +215,7 @@ def get_relevant( num_results=top_k, return_score=True, filter_expression=session_filter, - dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) messages = self._index.query(query) @@ -343,7 +341,7 @@ def add_messages( if TOOL_FIELD_NAME in message: chat_message.tool_call_id = message[TOOL_FIELD_NAME] - chat_messages.append(chat_message.to_dict(dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype)) # type: ignore[union-attr] + chat_messages.append(chat_message.to_dict(dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype)) # type: ignore[union-attr] self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) From 54d3e70ebb93d3bea0ef865a5f57cefc3275f1cd Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 1 Oct 2024 17:38:21 -0700 Subject: [PATCH 25/25] fixes tests that fail with redis-stack 6.2.6 --- .github/workflows/run_tests.yml | 7 +- tests/integration/test_llmcache.py | 51 ++++++------- tests/integration/test_semantic_router.py | 91 +++++++++++------------ tests/integration/test_session_manager.py | 50 ++++++------- 4 files changed, 85 insertions(+), 114 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index f37cd4ca..6b6b47d6 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -70,12 +70,7 @@ jobs: poetry run test-cov - name: Run tests - if: matrix.redis-stack-version == '6.2.6-v9' - run: | - SKIP_DTYPES=True SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov - - - name: Run tests - if: matrix.redis-stack-version == 'edge' || (matrix.connection == 'hiredis' && matrix.redis-stack-version == 'latest') + if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest' run: | SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index a3307ac2..4eb4d6f7 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -18,13 +18,6 @@ def vectorizer(): return HFTextVectorizer("sentence-transformers/all-mpnet-base-v2") -@pytest.fixture -def skip_dtypes() -> bool: - # os.getenv returns a string - v = os.getenv("SKIP_DTYPES", "False").lower() == "true" - return v - - @pytest.fixture def cache(vectorizer, redis_url): cache_instance = SemanticCache( @@ -829,7 +822,6 @@ def test_no_key_collision_on_identical_prompts(redis_url): private_cache.store( prompt="What's the phone number linked in my account?", response="The number on file is 123-555-9999", - ###filters={"user_id": "cerioni"}, filters={"user_id": "cerioni", "zip_code": 90210}, ) @@ -853,34 +845,33 @@ def test_no_key_collision_on_identical_prompts(redis_url): assert len(filtered_results) == 2 -def test_create_cache_with_different_vector_types(skip_dtypes): - if skip_dtypes: - pytest.skip("Skipping dtype checking...") - - bfloat_cache = SemanticCache(name="bfloat_cache", dtype="bfloat16") - bfloat_cache.store("bfloat16 prompt", "bfloat16 response") - - float16_cache = SemanticCache(name="float16_cache", dtype="float16") - float16_cache.store("float16 prompt", "float16 response") - - float32_cache = SemanticCache(name="float32_cache", dtype="float32") - float32_cache.store("float32 prompt", "float32 response") +def test_create_cache_with_different_vector_types(): + try: + bfloat_cache = SemanticCache(name="bfloat_cache", dtype="bfloat16") + bfloat_cache.store("bfloat16 prompt", "bfloat16 response") - float64_cache = SemanticCache(name="float64_cache", dtype="float64") - float64_cache.store("float64 prompt", "float64 response") + float16_cache = SemanticCache(name="float16_cache", dtype="float16") + float16_cache.store("float16 prompt", "float16 response") - for cache in [bfloat_cache, float16_cache, float32_cache, float64_cache]: - cache.set_threshold(0.6) - assert len(cache.check("float prompt", num_results=5)) == 1 + float32_cache = SemanticCache(name="float32_cache", dtype="float32") + float32_cache.store("float32 prompt", "float32 response") + float64_cache = SemanticCache(name="float64_cache", dtype="float64") + float64_cache.store("float64 prompt", "float64 response") -def test_bad_dtype_connecting_to_existing_cache(skip_dtypes): - if skip_dtypes: - pytest.skip("Skipping dtype checking...") + for cache in [bfloat_cache, float16_cache, float32_cache, float64_cache]: + cache.set_threshold(0.6) + assert len(cache.check("float prompt", num_results=5)) == 1 + except: + pytest.skip("Not using a late enough version of Redis") - cache = SemanticCache(name="float64_cache", dtype="float64") - same_type = SemanticCache(name="float64_cache", dtype="float64") +def test_bad_dtype_connecting_to_existing_cache(): + try: + cache = SemanticCache(name="float64_cache", dtype="float64") + same_type = SemanticCache(name="float64_cache", dtype="float64") + except ValueError: + pytest.skip("Not using a late enough version of Redis") with pytest.raises(ValueError): bad_type = SemanticCache(name="float64_cache", dtype="float16") diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 71422264..194a6f98 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -9,13 +9,6 @@ from redisvl.redis.connection import compare_versions -@pytest.fixture -def skip_dtypes() -> bool: - # os.getenv returns a string - v = os.getenv("SKIP_DTYPES", "False").lower() == "true" - return v - - def get_base_path(): return pathlib.Path(__file__).parent.resolve() @@ -183,7 +176,7 @@ def test_to_dict(semantic_router): def test_from_dict(semantic_router): router_dict = semantic_router.to_dict() new_router = SemanticRouter.from_dict( - router_dict, redis_client=semantic_router._index.client + router_dict, redis_client=semantic_router._index.client, overwrite=True ) assert new_router == semantic_router @@ -231,7 +224,7 @@ def test_yaml_invalid_file_path(): def test_idempotent_to_dict(semantic_router): router_dict = semantic_router.to_dict() new_router = SemanticRouter.from_dict( - router_dict, redis_client=semantic_router._index.client + router_dict, redis_client=semantic_router._index.client, overwrite=True ) assert new_router.to_dict() == router_dict @@ -247,53 +240,53 @@ def test_bad_connection_info(routes): ) -def test_different_vector_dtypes(routes, skip_dtypes): - if skip_dtypes: - pytest.skip("Skipping dtype checking...") - - bfloat_router = SemanticRouter( - name="bfloat_router", - routes=routes, - dtype="bfloat16", - ) - - float16_router = SemanticRouter( - name="float16_router", - routes=routes, - dtype="float16", - ) +def test_different_vector_dtypes(routes): + try: + bfloat_router = SemanticRouter( + name="bfloat_router", + routes=routes, + dtype="bfloat16", + ) - float32_router = SemanticRouter( - name="float32_router", - routes=routes, - dtype="float32", - ) + float16_router = SemanticRouter( + name="float16_router", + routes=routes, + dtype="float16", + ) - float64_router = SemanticRouter( - name="float64_router", - routes=routes, - dtype="float64", - ) + float32_router = SemanticRouter( + name="float32_router", + routes=routes, + dtype="float32", + ) - for router in [bfloat_router, float16_router, float32_router, float64_router]: - assert len(router.route_many("hello", max_k=5)) == 1 + float64_router = SemanticRouter( + name="float64_router", + routes=routes, + dtype="float64", + ) + for router in [bfloat_router, float16_router, float32_router, float64_router]: + assert len(router.route_many("hello", max_k=5)) == 1 + except: + pytest.skip("Not using a late enough version of Redis") -def test_bad_dtype_connecting_to_exiting_router(routes, skip_dtypes): - if skip_dtypes: - pytest.skip("Skipping dtype checking...") - router = SemanticRouter( - name="float64 router", - routes=routes, - dtype="float64", - ) +def test_bad_dtype_connecting_to_exiting_router(routes): + try: + router = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float64", + ) - same_type = SemanticRouter( - name="float64 router", - routes=routes, - dtype="float64", - ) + same_type = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float64", + ) + except ValueError: + pytest.skip("Not using a late enough version of Redis") with pytest.raises(ValueError): bad_type = SemanticRouter( diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index adc74988..898c3ab5 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -10,13 +10,6 @@ ) -@pytest.fixture -def skip_dtypes() -> bool: - # os.getenv returns a string - v = os.getenv("SKIP_DTYPES", "False").lower() == "true" - return v - - @pytest.fixture def standard_session(app_name, client): session = StandardSessionManager(app_name, redis_client=client) @@ -548,34 +541,33 @@ def test_semantic_drop(semantic_session): ] -def test_different_vector_dtypes(skip_dtypes): - if skip_dtypes: - pytest.skip("Skipping dtype checking...") - - bfloat_sess = SemanticSessionManager(name="bfloat_session", dtype="bfloat16") - bfloat_sess.add_message({"role": "user", "content": "bfloat message"}) - - float16_sess = SemanticSessionManager(name="float16_session", dtype="float16") - float16_sess.add_message({"role": "user", "content": "float16 message"}) - - float32_sess = SemanticSessionManager(name="float32_session", dtype="float32") - float32_sess.add_message({"role": "user", "content": "float32 message"}) +def test_different_vector_dtypes(): + try: + bfloat_sess = SemanticSessionManager(name="bfloat_session", dtype="bfloat16") + bfloat_sess.add_message({"role": "user", "content": "bfloat message"}) - float64_sess = SemanticSessionManager(name="float64_session", dtype="float64") - float64_sess.add_message({"role": "user", "content": "float64 message"}) + float16_sess = SemanticSessionManager(name="float16_session", dtype="float16") + float16_sess.add_message({"role": "user", "content": "float16 message"}) - for sess in [bfloat_sess, float16_sess, float32_sess, float64_sess]: - sess.set_distance_threshold(0.7) - assert len(sess.get_relevant("float message")) == 1 + float32_sess = SemanticSessionManager(name="float32_session", dtype="float32") + float32_sess.add_message({"role": "user", "content": "float32 message"}) + float64_sess = SemanticSessionManager(name="float64_session", dtype="float64") + float64_sess.add_message({"role": "user", "content": "float64 message"}) -def test_bad_dtype_connecting_to_exiting_session(skip_dtypes): - if skip_dtypes: - pytest.skip("Skipping dtype checking...") + for sess in [bfloat_sess, float16_sess, float32_sess, float64_sess]: + sess.set_distance_threshold(0.7) + assert len(sess.get_relevant("float message")) == 1 + except: + pytest.skip("Not using a late enough version of Redis") - session = SemanticSessionManager(name="float64 session", dtype="float64") - same_type = SemanticSessionManager(name="float64 session", dtype="float64") +def test_bad_dtype_connecting_to_exiting_session(): + try: + session = SemanticSessionManager(name="float64 session", dtype="float64") + same_type = SemanticSessionManager(name="float64 session", dtype="float64") + except ValueError: + pytest.skip("Not using a late enough version of Redis") with pytest.raises(ValueError): bad_type = SemanticSessionManager(name="float64 session", dtype="float16")