Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for new vector data types #222

Merged
merged 27 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b0d8e91
adds ml_dtypes as dependency for Bfloat16
justin-cechmanek Sep 17, 2024
80cdd32
wip: tests pass with float32 default, except test_to_yaml_and_reload
justin-cechmanek Sep 18, 2024
17d9931
wip: removes dtype from IndexInfo and reads it from field attrs
justin-cechmanek Sep 19, 2024
95b1ec0
adds session manager schema checks
justin-cechmanek Sep 19, 2024
a65eca7
cleans up session manager vector typing
justin-cechmanek Sep 19, 2024
947ccaf
updates router to specify vector dtype
justin-cechmanek Sep 19, 2024
196ca0a
makes dtype required for CacheEntry
justin-cechmanek Sep 19, 2024
80b919d
Merge branch 'main' into feat/RAAE-206/vector-dtypes
justin-cechmanek Sep 20, 2024
2a2d5b4
formatting
justin-cechmanek Sep 20, 2024
9b0d86a
addressing PR comments
justin-cechmanek Sep 20, 2024
0b98281
changes dtype arg to string in notebook
justin-cechmanek Sep 22, 2024
d289a44
specifies vector dtype when creating byte vectors in notebooks
justin-cechmanek Sep 22, 2024
45718c7
adds kwargs to custom embedding function to allow for accepting dtype
justin-cechmanek Sep 22, 2024
0c215e4
updates docstring in semantic session to include overwrite argument
justin-cechmanek Sep 23, 2024
b348a22
removes VectorDataTyps dict to use existing VectorDataType Enum
justin-cechmanek Sep 23, 2024
fc9712e
changes enum membership check for python 3.9 compatibility
justin-cechmanek Sep 23, 2024
a8bf2df
changes dtype membership check
justin-cechmanek Sep 24, 2024
89400bd
updates GHA test workflow to skip vector dtypes on unsupported redis-s…
justin-cechmanek Sep 24, 2024
41f5693
skips more dtype tests on old redis stack version
justin-cechmanek Sep 24, 2024
6898f1d
removes dtype from class definitions, and uses constants instead
justin-cechmanek Sep 25, 2024
18ccf16
lowers required search module to 2.6.20
justin-cechmanek Sep 25, 2024
816eeeb
lowers required search module to 2.6.12
justin-cechmanek Sep 25, 2024
d4ad2d2
super hacky fix to version compatibility issue
justin-cechmanek Sep 25, 2024
f0efe5c
reverts hacky fix and module version checked
justin-cechmanek Sep 26, 2024
2de27ba
Merge branch 'main' into feat/RAAE-206/vector-dtypes
justin-cechmanek Sep 30, 2024
462af6c
removes local vector field name constant
justin-cechmanek Oct 2, 2024
54d3e70
fixes tests that fail with redis-stack 6.2.6
justin-cechmanek Oct 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/schema.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Each field type supports specific attributes that customize its behavior. Below

- `dims`: Dimensionality of the vector.
- `algorithm`: Indexing algorithm (`flat` or `hnsw`).
- `datatype`: Float datatype of the vector (`float32` or `float64`).
- `datatype`: Float datatype of the vector (`bfloat16`, `float16`, `float32`, `float64`).
- `distance_metric`: Metric for measuring query relevance (`COSINE`, `L2`, `IP`).

**HNSW Vector Field Specific Attributes**:
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/openai_qna.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@
"api_key = os.getenv(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n",
"oaip = OpenAITextVectorizer(EMBEDDINGS_MODEL, api_config={\"api_key\": api_key})\n",
"\n",
"chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True)\n",
"chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True, dtype=\"float32\")\n",
"chunked_data"
]
},
Expand Down Expand Up @@ -1073,7 +1073,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.2"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/hash_vs_json_05.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@
"json_data = data.copy()\n",
"\n",
"for d in json_data:\n",
" d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype=np.float32)"
" d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype='float32')"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/vectorizers_04.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@
"outputs": [],
"source": [
"# You can also create many embeddings at once\n",
"embeddings = hf.embed_many(sentences, as_buffer=True)\n"
"embeddings = hf.embed_many(sentences, as_buffer=True, dtype=\"float32\")\n"
]
},
{
Expand Down Expand Up @@ -569,7 +569,7 @@
"source": [
"from redisvl.utils.vectorize import CustomTextVectorizer\n",
"\n",
"def generate_embeddings(text_input):\n",
"def generate_embeddings(text_input, **kwargs):\n",
" return [0.101] * 768\n",
"\n",
"custom_vectorizer = CustomTextVectorizer(generate_embeddings)\n",
Expand Down
45 changes: 41 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ redis = ">=5.0.0"
pydantic = { version = ">=2,<3" }
tenacity = ">=8.2.2"
tabulate = { version = ">=0.9.0,<1" }
ml-dtypes = "^0.4.0"
openai = { version = ">=1.13.0", optional = true }
sentence-transformers = { version = ">=2.2.2", optional = true }
google-cloud-aiplatform = { version = ">=1.26", optional = true }
Expand Down
8 changes: 4 additions & 4 deletions redisvl/extensions/llmcache/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def non_empty_metadata(cls, v):
raise TypeError("Metadata must be a dictionary.")
return v

def to_dict(self) -> Dict:
def to_dict(self, dtype: str) -> Dict:
data = self.dict(exclude_none=True)
data["prompt_vector"] = array_to_buffer(self.prompt_vector)
data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype)
if self.metadata is not None:
data["metadata"] = serialize(self.metadata)
if self.filters is not None:
Expand Down Expand Up @@ -112,7 +112,7 @@ def to_dict(self) -> Dict:
class SemanticCacheIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str, vector_dims: int):
def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
Expand All @@ -126,7 +126,7 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
"type": "vector",
"attrs": {
"dims": vector_dims,
"datatype": "float32",
"datatype": dtype,
"distance_metric": "cosine",
"algorithm": "flat",
},
Expand Down
13 changes: 9 additions & 4 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def __init__(
]

# Create semantic cache schema and index
schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims)
dtype = kwargs.get("dtype", "float32")
justin-cechmanek marked this conversation as resolved.
Show resolved Hide resolved
schema = SemanticCacheIndexSchema.from_params(
name, prefix, vectorizer.dims, dtype
)
schema = self._modify_schema(schema, filterable_fields)
self._index = SearchIndex(schema=schema)

Expand Down Expand Up @@ -137,6 +140,7 @@ def __init__(
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
)
self._vectorizer = vectorizer
self._dtype = self.index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr]

def _modify_schema(
self,
Expand Down Expand Up @@ -286,7 +290,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")

return self._vectorizer.embed(prompt)
return self._vectorizer.embed(prompt, dtype=self._dtype)

async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
"""Converts a text prompt to its vector representation using the
Expand Down Expand Up @@ -368,6 +372,7 @@ def check(
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
dtype=self._dtype,
)

# Search the cache!
Expand Down Expand Up @@ -538,7 +543,7 @@ def store(
# Load cache entry with TTL
ttl = ttl or self._ttl
keys = self._index.load(
data=[cache_entry.to_dict()],
data=[cache_entry.to_dict(self._dtype)],
ttl=ttl,
id_field=ENTRY_ID_FIELD_NAME,
)
Expand Down Expand Up @@ -602,7 +607,7 @@ async def astore(
# Load cache entry with TTL
ttl = ttl or self._ttl
keys = await aindex.load(
data=[cache_entry.to_dict()],
data=[cache_entry.to_dict(self._dtype)],
ttl=ttl,
id_field=ENTRY_ID_FIELD_NAME,
)
Expand Down
8 changes: 4 additions & 4 deletions redisvl/extensions/router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic.v1 import BaseModel, Field, validator

from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.schema import IndexInfo, IndexSchema
from redisvl.schema import IndexSchema


class Route(BaseModel):
Expand Down Expand Up @@ -89,7 +89,7 @@ class SemanticRouterIndexSchema(IndexSchema):
"""Customized index schema for SemanticRouter."""

@classmethod
def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
def from_params(cls, name: str, vector_dims: int, dtype: str):
"""Create an index schema based on router name and vector dimensions.

Args:
Expand All @@ -100,7 +100,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
SemanticRouterIndexSchema: The constructed index schema.
"""
return cls(
index=IndexInfo(name=name, prefix=name),
index={"name": name, "prefix": name}, # type: ignore
fields=[ # type: ignore
{"name": "route_name", "type": "tag"},
{"name": "reference", "type": "text"},
Expand All @@ -111,7 +111,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
"algorithm": "flat",
"dims": vector_dims,
"distance_metric": "cosine",
"datatype": "float32",
"datatype": dtype,
},
},
],
Expand Down
28 changes: 24 additions & 4 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,42 @@ def __init__(
vectorizer=vectorizer,
routing_config=routing_config,
)
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
dtype = kwargs.get("dtype", "float32")
self._initialize_index(
redis_client, redis_url, overwrite, dtype, **connection_kwargs
)

def _initialize_index(
self,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
dtype: str = "float32",
**connection_kwargs,
):
"""Initialize the search index and handle Redis connection."""
schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims)
schema = SemanticRouterIndexSchema.from_params(
self.name, self.vectorizer.dims, dtype
)
self._index = SearchIndex(schema=schema)

if redis_client:
self._index.set_client(redis_client)
elif redis_url:
self._index.connect(redis_url=redis_url, **connection_kwargs)

# Check for existing router index
existed = self._index.exists()
self._index.create(overwrite=overwrite)
if not overwrite and existed:
existing_index = SearchIndex.from_existing(
self.name, redis_client=self._index.client
)
if existing_index.schema != self._index.schema:
raise ValueError(
f"Existing index {self.name} schema does not match the user provided schema for the semantic router. "
"If you wish to overwrite the index schema, set overwrite=True during initialization."
)
self._index.create(overwrite=overwrite, drop=False)

if not existed or overwrite:
# write the routes to Redis
Expand Down Expand Up @@ -153,7 +169,9 @@ def _add_routes(self, routes: List[Route]):
for route in routes:
# embed route references as a single batch
reference_vectors = self.vectorizer.embed_many(
[reference for reference in route.references], as_buffer=True
[reference for reference in route.references],
as_buffer=True,
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
# set route references
for i, reference in enumerate(route.references):
Expand Down Expand Up @@ -230,6 +248,7 @@ def _classify_route(
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
distance_threshold=distance_threshold,
return_fields=["route_name"],
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)

aggregate_request = self._build_aggregate_request(
Expand Down Expand Up @@ -282,6 +301,7 @@ def _classify_multi_route(
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
distance_threshold=distance_threshold,
return_fields=["route_name"],
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
aggregate_request = self._build_aggregate_request(
vector_range_query, aggregation_method, max_k
Expand Down
9 changes: 4 additions & 5 deletions redisvl/extensions/session_manager/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ def generate_id(cls, values):
)
return values

def to_dict(self) -> Dict:
def to_dict(self, dtype: Optional[str] = None) -> Dict:
data = self.dict(exclude_none=True)

# handle optional fields
if SESSION_VECTOR_FIELD_NAME in data:
data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer(
data[SESSION_VECTOR_FIELD_NAME]
data[SESSION_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type]
)

return data


Expand All @@ -80,7 +79,7 @@ def from_params(cls, name: str, prefix: str):
class SemanticSessionIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
Expand All @@ -95,7 +94,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
"type": "vector",
"attrs": {
"dims": vectorizer_dims,
"datatype": "float32",
"datatype": dtype,
"distance_metric": "cosine",
"algorithm": "flat",
},
Expand Down
Loading
Loading