Skip to content

Commit

Permalink
Merge pull request #256 from aurelio-labs/juan/fix-encoder-arrays
Browse files Browse the repository at this point in the history
fix: Added fix to encode documents within rolling window
  • Loading branch information
jamescalam authored Apr 27, 2024
2 parents 302fe17 + 3b78805 commit fd8cc15
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 20 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
- name: Pytest
env:
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
make test
- name: Upload coverage to Codecov
Expand Down
4 changes: 2 additions & 2 deletions semantic_router/encoders/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List

from pydantic.v1 import BaseModel, Field

Expand All @@ -11,5 +11,5 @@ class BaseEncoder(BaseModel):
class Config:
arbitrary_types_allowed = True

def __call__(self, docs: List[Any]) -> List[List[float]]:
    """Encode a batch of documents into embedding vectors.

    Abstract hook on the base encoder: concrete subclasses (OpenAI,
    Cohere, ...) must override it.

    :param docs: documents to encode. Typed ``List[Any]`` (not
        ``List[str]``) so encoders that accept non-string inputs are
        not rejected by the type checker.
    :return: one embedding (list of floats) per input document.
    :raises NotImplementedError: always, on the base class.
    """
    raise NotImplementedError("Subclasses must implement this method")
2 changes: 1 addition & 1 deletion semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseIndex(BaseModel):
type: str = "base"

def add(
self, embeddings: List[List[float]], routes: List[str], utterances: List[str]
self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
):
"""
Add embeddings to the index.
Expand Down
41 changes: 34 additions & 7 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,27 @@ class PineconeIndex(BaseIndex):
ServerlessSpec: Any = Field(default=None, exclude=True)
namespace: Optional[str] = ""

def __init__(
    self,
    api_key: Optional[str] = None,
    index_name: str = "index",
    dimensions: Optional[int] = None,
    metric: str = "cosine",
    cloud: str = "aws",
    region: str = "us-west-2",
    host: str = "",
    namespace: Optional[str] = "",
):
    """Create a PineconeIndex wrapper with explicit configuration.

    Note the remote index itself is NOT created here: ``_init_index`` is
    deferred until the embedding dimensionality is known, so this object
    can be constructed before any vectors exist (see ``_init_index``'s
    docstring). Only the Pinecone client is initialized eagerly.

    :param api_key: Pinecone API key; when None, ``_initialize_client``
        presumably falls back to the environment — TODO confirm.
    :param index_name: name of the Pinecone index to use or create.
    :param dimensions: embedding dimensionality; may be set later.
    :param metric: similarity metric for the index (e.g. "cosine").
    :param cloud: serverless cloud provider.
    :param region: serverless region.
    :param host: optional explicit index host.
    :param namespace: Pinecone namespace that scopes reads/writes.
    """
    super().__init__()
    self.index_name = index_name
    self.dimensions = dimensions
    self.metric = metric
    self.cloud = cloud
    self.region = region
    self.host = host
    self.namespace = namespace
    self.type = "pinecone"
    # client only; the index is initialized lazily via _init_index
    self.client = self._initialize_client(api_key=api_key)

def _initialize_client(self, api_key: Optional[str] = None):
try:
Expand All @@ -77,6 +92,18 @@ def _initialize_client(self, api_key: Optional[str] = None):
return Pinecone(**pinecone_args)

def _init_index(self, force_create: bool = False) -> Union[Any, None]:
"""Initializing the index can be done after the object has been created
to allow for the user to set the dimensions and other parameters.
If the index doesn't exist and the dimensions are given, the index will
be created. If the index exists, it will be returned. If the index doesn't
exist and the dimensions are not given, the index will not be created and
None will be returned.
:param force_create: If True, the index will be created even if the
dimensions are not given (which will raise an error).
:type force_create: bool, optional
"""
index_exists = self.index_name in self.client.list_indexes().names()
dimensions_given = self.dimensions is not None
if dimensions_given and not index_exists:
Expand All @@ -95,7 +122,7 @@ def _init_index(self, force_create: bool = False) -> Union[Any, None]:
time.sleep(0.5)
elif index_exists:
# if the index exists we just return it
index = self.client.Index(self.index_name, namespace=self.namespace)
index = self.client.Index(self.index_name)
# grab the dimensions from the index
self.dimensions = index.describe_index_stats()["dimension"]
elif force_create and not dimensions_given:
Expand Down Expand Up @@ -207,7 +234,7 @@ def get_routes(self) -> List[Tuple]:
def delete(self, route_name: str):
    """Delete every vector stored for *route_name* from the index.

    The vector ids for the route are looked up first, then removed in a
    single call scoped to this wrapper's namespace (the PR #256 fix —
    without ``namespace=`` the delete targeted the default namespace).

    :param route_name: name of the route whose utterance vectors are
        removed.
    :raises ValueError: if the index has not been initialized yet.
    """
    route_vec_ids = self._get_route_ids(route_name=route_name)
    if self.index is not None:
        self.index.delete(ids=route_vec_ids, namespace=self.namespace)
    else:
        raise ValueError("Index is None, could not delete.")

Expand Down
8 changes: 4 additions & 4 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
# create embeddings
embeds = self.encoder(route.utterances) # type: ignore
embeds = self.encoder(route.utterances)
# if route has no score_threshold, use default
if route.score_threshold is None:
route.score_threshold = self.score_threshold
Expand All @@ -363,7 +363,7 @@ def add(self, route: Route):
self.index.add(
embeddings=embeds,
routes=[route.name] * len(route.utterances),
utterances=route.utterances, # type: ignore
utterances=route.utterances,
)
self.routes.append(route)

Expand Down Expand Up @@ -409,14 +409,14 @@ def _add_routes(self, routes: List[Route]):
all_utterances = [
utterance for route in routes for utterance in route.utterances
]
embedded_utterances = self.encoder(all_utterances) # type: ignore
embedded_utterances = self.encoder(all_utterances)
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index.add(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances, # type: ignore
utterances=all_utterances,
)

def _encode(self, text: str) -> Any:
Expand Down
27 changes: 21 additions & 6 deletions semantic_router/splitters/rolling_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,27 @@ def __call__(self, docs: List[str]) -> List[DocumentSplit]:
return splits

def _encode_documents(self, docs: List[str]) -> np.ndarray:
    """Encode *docs* into a 2-D array of embeddings, batching requests.

    Documents are sent to the encoder in batches of at most 2000,
    because OpenAI rejects embedding requests whose input array has
    length >= 2048.

    :param docs: text documents to be encoded.
    :return: a numpy array with one embedding row per document.
    :raises Exception: re-raises any encoder error after logging the
        batch that failed.
    """
    max_docs_per_batch = 2000
    embeddings: List[List[float]] = []

    for i in range(0, len(docs), max_docs_per_batch):
        batch_docs = docs[i : i + max_docs_per_batch]
        try:
            batch_embeddings = self.encoder(batch_docs)
            embeddings.extend(batch_embeddings)
        except Exception as e:
            # log which batch failed before propagating to the caller
            logger.error(f"Error encoding documents {batch_docs}: {e}")
            raise

    return np.array(embeddings)

def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
raw_similarities = []
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import mock_open, patch

import pytest
import time

from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
from semantic_router.index.local import LocalIndex
Expand Down Expand Up @@ -279,6 +280,23 @@ def test_query_filter_pinecone(self, openai_encoder, routes, index_cls):
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pineconeindex
)
time.sleep(10) # allow for index to be populated
query_result = route_layer(text="Hello", route_filter=["Route 1"]).name

try:
route_layer(text="Hello", route_filter=["Route 8"]).name
except ValueError:
assert True

assert query_result in ["Route 1"]

def test_namespace_pinecone_index(self, openai_encoder, routes, index_cls):
pinecone_api_key = os.environ["PINECONE_API_KEY"]
pineconeindex = PineconeIndex(api_key=pinecone_api_key, namespace="test")
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pineconeindex
)
time.sleep(10) # allow for index to be populated
query_result = route_layer(text="Hello", route_filter=["Route 1"]).name

try:
Expand Down

0 comments on commit fd8cc15

Please sign in to comment.