Skip to content

Commit

Permalink
Merge pull request #82 from aurelio-labs/ashraq/custom-score-threshold
Browse files Browse the repository at this point in the history
feat: Move score_threshold to encoders
  • Loading branch information
jamescalam committed Jan 7, 2024
2 parents 603e4f5 + 59713d5 commit 99ce319
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 40 deletions.
1 change: 1 addition & 0 deletions semantic_router/encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

class BaseEncoder(BaseModel):
name: str
score_threshold: float
type: str = Field(default="base")

class Config:
Expand Down
6 changes: 4 additions & 2 deletions semantic_router/encoders/bm25.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from typing import Any

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger


class BM25Encoder(BaseEncoder):
model: Any | None = None
idx_mapping: dict[int, int] | None = None
type: str = "sparse"

def __init__(self, name: str = "bm25"):
super().__init__(name=name)
def __init__(self, name: str = "bm25", score_threshold: float = 0.82):
super().__init__(name=name, score_threshold=score_threshold)
try:
from pinecone_text.sparse import BM25Encoder as encoder
except ImportError:
raise ImportError(
"Please install pinecone-text to use BM25Encoder. "
"You can install it with: `pip install semantic-router[hybrid]`"
)
logger.info("Downloading and initializing BM25 model parameters.")
self.model = encoder.default()

params = self.model.get_params()
Expand Down
3 changes: 2 additions & 1 deletion semantic_router/encoders/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ def __init__(
self,
name: str | None = None,
cohere_api_key: str | None = None,
score_threshold: float = 0.3,
):
if name is None:
name = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0")
super().__init__(name=name)
super().__init__(name=name, score_threshold=score_threshold)
cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
if cohere_api_key is None:
raise ValueError("Cohere API key cannot be 'None'.")
Expand Down
6 changes: 4 additions & 2 deletions semantic_router/encoders/fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ class FastEmbedEncoder(BaseEncoder):
threads: Optional[int] = None
_client: Any = PrivateAttr()

def __init__(self, **data):
super().__init__(**data)
def __init__(
self, score_threshold: float = 0.5, **data
): # TODO default score_threshold not thoroughly tested, should optimize
super().__init__(score_threshold=score_threshold, **data)
self._client = self._initialize_client()

def _initialize_client(self):
Expand Down
3 changes: 2 additions & 1 deletion semantic_router/encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ def __init__(
self,
name: str | None = None,
openai_api_key: str | None = None,
score_threshold: float = 0.82,
):
if name is None:
name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002")
super().__init__(name=name)
super().__init__(name=name, score_threshold=score_threshold)
api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
if api_key is None:
raise ValueError("OpenAI API key cannot be 'None'.")
Expand Down
12 changes: 2 additions & 10 deletions semantic_router/hybrid_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from semantic_router.encoders import (
BaseEncoder,
BM25Encoder,
CohereEncoder,
OpenAIEncoder,
)
from semantic_router.route import Route
from semantic_router.utils.logger import logger
Expand All @@ -15,21 +13,15 @@ class HybridRouteLayer:
index = None
sparse_index = None
categories = None
score_threshold = 0.82
score_threshold: float

def __init__(
self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3
):
self.encoder = encoder
self.score_threshold = self.encoder.score_threshold
self.sparse_encoder = BM25Encoder()
self.alpha = alpha
# decide on default threshold based on encoder
if isinstance(encoder, OpenAIEncoder):
self.score_threshold = 0.82
elif isinstance(encoder, CohereEncoder):
self.score_threshold = 0.3
else:
self.score_threshold = 0.82
# if routes list has been passed, we initialize index now
if routes:
# initialize index now
Expand Down
31 changes: 12 additions & 19 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
import numpy as np
import yaml

from semantic_router.encoders import (
BaseEncoder,
CohereEncoder,
FastEmbedEncoder,
OpenAIEncoder,
)
from semantic_router.encoders import BaseEncoder, OpenAIEncoder
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route
Expand Down Expand Up @@ -154,7 +149,8 @@ def remove(self, name: str):
class RouteLayer:
index: np.ndarray | None = None
categories: np.ndarray | None = None
score_threshold: float = 0.82
score_threshold: float
encoder: BaseEncoder

def __init__(
self,
Expand All @@ -165,20 +161,17 @@ def __init__(
logger.info("Initializing RouteLayer")
self.index = None
self.categories = None
self.encoder = encoder if encoder is not None else CohereEncoder()
if encoder is None:
logger.warning(
"No encoder provided. Using default OpenAIEncoder. Ensure "
"that you have set OPENAI_API_KEY in your environment."
)
self.encoder = OpenAIEncoder()
else:
self.encoder = encoder
self.llm = llm
self.routes: list[Route] = routes if routes is not None else []
# decide on default threshold based on encoder
# TODO move defaults to the encoder objects and extract from there
if isinstance(encoder, OpenAIEncoder):
self.score_threshold = 0.82
elif isinstance(encoder, CohereEncoder):
self.score_threshold = 0.3
elif isinstance(encoder, FastEmbedEncoder):
# TODO default not thoroughly tested, should optimize
self.score_threshold = 0.5
else:
self.score_threshold = 0.82
self.score_threshold = self.encoder.score_threshold
# if routes list has been passed, we initialize index now
if len(self.routes) > 0:
# initialize index now
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/encoders/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
class TestBaseEncoder:
@pytest.fixture
def base_encoder(self):
return BaseEncoder(name="TestEncoder")
return BaseEncoder(name="TestEncoder", score_threshold=0.5)

def test_base_encoder_initialization(self, base_encoder):
assert base_encoder.name == "TestEncoder", "Initialization of name failed"
assert base_encoder.score_threshold == 0.5

def test_base_encoder_call_method_not_implemented(self, base_encoder):
with pytest.raises(NotImplementedError):
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_hybrid_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def mock_encoder_call(utterances):

@pytest.fixture
def base_encoder():
return BaseEncoder(name="test-encoder")
return BaseEncoder(name="test-encoder", score_threshold=0.5)


@pytest.fixture
Expand All @@ -46,6 +46,7 @@ class TestHybridRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
assert route_layer.index is not None and route_layer.categories is not None
assert openai_encoder.score_threshold == 0.82
assert route_layer.score_threshold == 0.82
assert len(route_layer.index) == 5
assert len(set(route_layer.categories)) == 2
Expand Down Expand Up @@ -112,7 +113,8 @@ def test_pass_threshold(self, openai_encoder):

def test_failover_score_threshold(self, base_encoder):
route_layer = HybridRouteLayer(encoder=base_encoder)
assert route_layer.score_threshold == 0.82
assert base_encoder.score_threshold == 0.50
assert route_layer.score_threshold == 0.50


# Add more tests for edge cases and error handling as needed.
12 changes: 10 additions & 2 deletions tests/unit/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def layer_yaml():

@pytest.fixture
def base_encoder():
return BaseEncoder(name="test-encoder")
return BaseEncoder(name="test-encoder", score_threshold=0.5)


@pytest.fixture
Expand Down Expand Up @@ -103,6 +103,7 @@ def dynamic_routes():
class TestRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
assert openai_encoder.score_threshold == 0.82
assert route_layer.score_threshold == 0.82
assert len(route_layer.index) if route_layer.index is not None else 0 == 5
assert (
Expand All @@ -113,14 +114,21 @@ def test_initialization(self, openai_encoder, routes):

def test_initialization_different_encoders(self, cohere_encoder, openai_encoder):
route_layer_cohere = RouteLayer(encoder=cohere_encoder)
assert cohere_encoder.score_threshold == 0.3
assert route_layer_cohere.score_threshold == 0.3
route_layer_openai = RouteLayer(encoder=openai_encoder)
assert route_layer_openai.score_threshold == 0.82

def test_initialization_no_encoder(self, openai_encoder):
os.environ["OPENAI_API_KEY"] = "test_api_key"
route_layer_none = RouteLayer(encoder=None)
assert route_layer_none.score_threshold == openai_encoder.score_threshold

def test_initialization_dynamic_route(self, cohere_encoder, openai_encoder):
route_layer_cohere = RouteLayer(encoder=cohere_encoder)
assert route_layer_cohere.score_threshold == 0.3
route_layer_openai = RouteLayer(encoder=openai_encoder)
assert openai_encoder.score_threshold == 0.82
assert route_layer_openai.score_threshold == 0.82

def test_add_route(self, openai_encoder):
Expand Down Expand Up @@ -186,7 +194,7 @@ def test_pass_threshold(self, openai_encoder):

def test_failover_score_threshold(self, base_encoder):
route_layer = RouteLayer(encoder=base_encoder)
assert route_layer.score_threshold == 0.82
assert route_layer.score_threshold == 0.5

def test_json(self, openai_encoder, routes):
with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as temp:
Expand Down

0 comments on commit 99ce319

Please sign in to comment.