diff --git a/README.md b/README.md
index da3fe685..fe5db343 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@
-Semantic Router is a superfast decision layer for your LLMs and agents. Rather than waiting for slow LLM generations to make tool-use decisions, we use the magic of semantic vector space to make those decisions — _routing_ our requests using _semantic_ meaning.
+Semantic Router is a superfast decision-making layer for your LLMs and agents. Rather than waiting for slow LLM generations to make tool-use decisions, we use the magic of semantic vector space to make those decisions — _routing_ our requests using _semantic_ meaning.
## Quickstart
@@ -22,7 +22,9 @@ To get started with _semantic-router_ we install it like so:
pip install -qU semantic-router
```
-We begin by defining a set of `Decision` objects. These are the decision paths that the semantic router can decide to use, let's try two simple decisions for now — one for talk on _politics_ and another for _chitchat_:
+❗️ _If you want to use local embeddings, you can use `FastEmbedEncoder` (`pip install -qU semantic-router[fastembed]`). To use the `HybridRouteLayer`, you must `pip install -qU semantic-router[hybrid]`._
+
+We begin by defining a set of `Route` objects. These are the decision paths that the semantic router can decide to use. Let's try two simple routes for now — one for talk about _politics_ and another for _chitchat_:
```python
from semantic_router import Route
@@ -56,7 +58,7 @@ chitchat = Route(
routes = [politics, chitchat]
```
-We have our decisions ready, now we initialize an embedding / encoder model. We currently support a `CohereEncoder` and `OpenAIEncoder` — more encoders will be added soon. To initialize them we do:
+We have our routes ready. Now we initialize an embedding / encoder model. We currently support a `CohereEncoder` and `OpenAIEncoder` — more encoders will be added soon. To initialize them we do:
```python
import os
@@ -71,18 +73,18 @@ os.environ["OPENAI_API_KEY"] = ""
encoder = OpenAIEncoder()
```
-With our `decisions` and `encoder` defined we now create a `DecisionLayer`. The decision layer handles our semantic decision making.
+With our `routes` and `encoder` defined, we now create a `RouteLayer`. The route layer handles our semantic decision-making.
```python
from semantic_router.layer import RouteLayer
-dl = RouteLayer(encoder=encoder, routes=routes)
+rl = RouteLayer(encoder=encoder, routes=routes)
```
-We can now use our decision layer to make super fast decisions based on user queries. Let's try with two queries that should trigger our decisions:
+We can now use our route layer to make super-fast decisions based on user queries. Let's try two queries that should trigger our route decisions:
```python
-dl("don't you love politics?").name
+rl("don't you love politics?").name
```
```
@@ -92,7 +94,7 @@ dl("don't you love politics?").name
Correct decision! Let's try another:
```python
-dl("how's the weather today?").name
+rl("how's the weather today?").name
```
```
@@ -102,14 +104,14 @@ dl("how's the weather today?").name
We get both decisions correct! Now let's try sending an unrelated query:
```python
-dl("I'm interested in learning about llama 2").name
+rl("I'm interested in learning about llama 2").name
```
```
[Out]:
```
-In this case, no decision could be made as we had no matches — so our decision layer returned `None`!
+In this case, no decision could be made as we had no matches — so our route layer returned `None`!
## 📚 [Resources](https://github.com/aurelio-labs/semantic-router/tree/main/docs)
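The installation note above points to `FastEmbedEncoder` for local embeddings. As a minimal sketch (assuming the `fastembed` extra is installed and that the encoder's default model is acceptable), swapping it into the quickstart might look like:

```python
# Sketch: local embeddings via the fastembed extra. Assumes
# `pip install -qU "semantic-router[fastembed]"` has been run.
from semantic_router import Route, RouteLayer
from semantic_router.encoders import FastEmbedEncoder

politics = Route(name="politics", utterances=["don't you love politics?"])
chitchat = Route(name="chitchat", utterances=["how's the weather today?"])

# FastEmbedEncoder runs locally, so no API key is needed here.
encoder = FastEmbedEncoder()  # assumes the default model is acceptable
rl = RouteLayer(encoder=encoder, routes=[politics, chitchat])
print(rl("don't you love politics?").name)  # -> "politics"
```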
diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb
index d8078cb2..c695838e 100644
--- a/docs/02-dynamic-routes.ipynb
+++ b/docs/02-dynamic-routes.ipynb
@@ -36,7 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install -qU semantic-router==0.0.14"
+ "!pip install -qU semantic-router==0.0.15"
]
},
{
@@ -64,17 +64,7 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"from semantic_router import Route\n",
"\n",
@@ -102,16 +92,23 @@
"routes = [politics, chitchat]"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We initialize our `RouteLayer` with our `encoder` and `routes`. We can use popular encoder APIs like `CohereEncoder` and `OpenAIEncoder`, or local alternatives like `FastEmbedEncoder`."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "\u001b[32m2023-12-28 19:19:39 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+ "\u001b[32m2024-01-07 15:23:12 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
]
}
],
@@ -119,13 +116,21 @@
"import os\n",
"from getpass import getpass\n",
"from semantic_router import RouteLayer\n",
+ "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
"\n",
"# dashboard.cohere.ai\n",
- "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
- " \"Enter Cohere API Key: \"\n",
+ "# os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
+ "# \"Enter Cohere API Key: \"\n",
+ "# )\n",
+ "# platform.openai.com\n",
+ "os.environ[\"OPENAI_API_KEY\"] = os.getenv(\"OPENAI_API_KEY\") or getpass(\n",
+ " \"Enter OpenAI API Key: \"\n",
")\n",
"\n",
- "rl = RouteLayer(routes=routes)"
+ "# encoder = CohereEncoder()\n",
+ "encoder = OpenAIEncoder()\n",
+ "\n",
+ "rl = RouteLayer(encoder=encoder, routes=routes)"
]
},
{
@@ -137,7 +142,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -146,7 +151,7 @@
"RouteChoice(name='chitchat', function_call=None)"
]
},
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -171,7 +176,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -193,16 +198,16 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "'13:19'"
+ "'09:23'"
]
},
- "execution_count": 7,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -220,7 +225,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -232,7 +237,7 @@
" 'output': \"\"}"
]
},
- "execution_count": 8,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -253,7 +258,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -277,16 +282,14 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "Adding route `get_time`\n",
- "Adding route to categories\n",
- "Adding route to index\n"
+ "\u001b[32m2024-01-07 15:23:16 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
]
}
],
@@ -303,31 +306,32 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "\u001b[32m2023-12-28 19:21:58 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
+ "\u001b[33m2024-01-07 15:23:17 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n",
+ "\u001b[32m2024-01-07 15:23:17 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
- "RouteChoice(name='get_time', function_call={'timezone': 'America/New_York'})"
+ "RouteChoice(name='get_time', function_call={'timezone': 'new york city'})"
]
},
- "execution_count": 11,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# https://openrouter.ai/keys\n",
- "os.environ[\"OPENROUTER_API_KEY\"] = os.getenv(\"OPENROUTER_API_KEY\") or getpass(\n",
- " \"Enter OpenRouter API Key: \"\n",
+ "# https://platform.openai.com/\n",
+ "os.environ[\"OPENAI_API_KEY\"] = os.getenv(\"OPENAI_API_KEY\") or getpass(\n",
+ " \"Enter OpenAI API Key: \"\n",
")\n",
"\n",
"rl(\"what is the time in new york city?\")"
diff --git a/pyproject.toml b/pyproject.toml
index d3561c64..b24ed4f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
-version = "0.0.14"
+version = "0.0.15"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs ",
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 6d85508c..72f8a8f0 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -7,10 +7,11 @@
from semantic_router.encoders import (
BaseEncoder,
CohereEncoder,
- OpenAIEncoder,
FastEmbedEncoder,
+ OpenAIEncoder,
)
from semantic_router.linear import similarity_matrix, top_scores
+from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice
from semantic_router.utils.logger import logger
@@ -156,12 +157,16 @@ class RouteLayer:
score_threshold: float = 0.82
def __init__(
- self, encoder: BaseEncoder | None = None, routes: list[Route] | None = None
+ self,
+ encoder: BaseEncoder | None = None,
+ llm: BaseLLM | None = None,
+ routes: list[Route] | None = None,
):
logger.info("Initializing RouteLayer")
self.index = None
self.categories = None
self.encoder = encoder if encoder is not None else CohereEncoder()
+ self.llm = llm
self.routes: list[Route] = routes if routes is not None else []
# decide on default threshold based on encoder
# TODO move defaults to the encoder objects and extract from there
@@ -186,6 +191,17 @@ def __call__(self, text: str) -> RouteChoice:
if passed:
# get chosen route object
route = [route for route in self.routes if route.name == top_class][0]
+ if route.function_schema and not isinstance(route.llm, BaseLLM):
+ if not self.llm:
+ logger.warning(
+ "No LLM provided for dynamic route, will use OpenAI LLM "
+ "default. Ensure API key is set in OPENAI_API_KEY environment "
+ "variable."
+ )
+ self.llm = OpenAILLM()
+ route.llm = self.llm
+ else:
+ route.llm = self.llm
return route(text)
else:
# if no route passes threshold, return empty route choice
@@ -216,24 +232,20 @@ def from_config(cls, config: LayerConfig):
return cls(encoder=encoder, routes=config.routes)
def add(self, route: Route):
- print(f"Adding route `{route.name}`")
+ logger.info(f"Adding `{route.name}` route")
# create embeddings
embeds = self.encoder(route.utterances)
# create route array
if self.categories is None:
- print("Initializing categories array")
self.categories = np.array([route.name] * len(embeds))
else:
- print("Adding route to categories")
str_arr = np.array([route.name] * len(embeds))
self.categories = np.concatenate([self.categories, str_arr])
# create utterance array (the index)
if self.index is None:
- print("Initializing index array")
self.index = np.array(embeds)
else:
- print("Adding route to index")
embed_arr = np.array(embeds)
self.index = np.concatenate([self.index, embed_arr])
# add route to routes list
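With the new `llm` parameter on `RouteLayer`, callers can supply the dynamic-route LLM explicitly instead of relying on the `OpenAILLM` fallback and its warning. A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment:

```python
# Sketch: pass the dynamic-route LLM explicitly rather than relying on the
# OpenAILLM default. Assumes OPENAI_API_KEY is set in the environment.
from semantic_router import RouteLayer
from semantic_router.encoders import OpenAIEncoder
from semantic_router.llms import OpenAILLM

rl = RouteLayer(
    encoder=OpenAIEncoder(),
    llm=OpenAILLM(),  # avoids the "No LLM provided" warning for dynamic routes
    routes=[],  # add your Route objects here
)
```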
diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py
new file mode 100644
index 00000000..e5aedc85
--- /dev/null
+++ b/semantic_router/llms/__init__.py
@@ -0,0 +1,6 @@
+from semantic_router.llms.base import BaseLLM
+from semantic_router.llms.cohere import CohereLLM
+from semantic_router.llms.openai import OpenAILLM
+from semantic_router.llms.openrouter import OpenRouterLLM
+
+__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM"]
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
new file mode 100644
index 00000000..51db1fd0
--- /dev/null
+++ b/semantic_router/llms/base.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel
+
+from semantic_router.schema import Message
+
+
+class BaseLLM(BaseModel):
+ name: str
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __call__(self, messages: list[Message]) -> str | None:
+ raise NotImplementedError("Subclasses must implement this method")
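Because `BaseLLM` is a pydantic model with a single abstract `__call__`, third-party or local models can be plugged in by subclassing it, much as the unit tests below do with a mock. A minimal sketch (the `EchoLLM` name and behaviour are illustrative, not part of the library):

```python
# Sketch of a custom LLM: subclass BaseLLM and implement __call__.
# `EchoLLM` is a hypothetical example, not part of semantic-router.
from semantic_router.llms import BaseLLM
from semantic_router.schema import Message


class EchoLLM(BaseLLM):
    def __call__(self, messages: list[Message]) -> str | None:
        # A real implementation would call a model; we just echo the last turn.
        return messages[-1].content


llm = EchoLLM(name="echo")
print(llm([Message(role="user", content="hello")]))  # -> "hello"
```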
diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py
new file mode 100644
index 00000000..77581700
--- /dev/null
+++ b/semantic_router/llms/cohere.py
@@ -0,0 +1,45 @@
+import os
+
+import cohere
+
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+
+
+class CohereLLM(BaseLLM):
+ client: cohere.Client | None = None
+
+ def __init__(
+ self,
+ name: str | None = None,
+ cohere_api_key: str | None = None,
+ ):
+ if name is None:
+ name = os.getenv("COHERE_CHAT_MODEL_NAME", "command")
+ super().__init__(name=name)
+ cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
+ if cohere_api_key is None:
+ raise ValueError("Cohere API key cannot be 'None'.")
+ try:
+ self.client = cohere.Client(cohere_api_key)
+ except Exception as e:
+ raise ValueError(f"Cohere API client failed to initialize. Error: {e}")
+
+ def __call__(self, messages: list[Message]) -> str:
+ if self.client is None:
+ raise ValueError("Cohere client is not initialized.")
+ try:
+ completion = self.client.chat(
+ model=self.name,
+ chat_history=[m.to_cohere() for m in messages[:-1]],
+ message=messages[-1].content,
+ )
+
+ output = completion.text
+
+ if not output:
+ raise Exception("No output generated")
+ return output
+
+ except Exception as e:
+ raise ValueError(f"Cohere API call failed. Error: {e}")
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
new file mode 100644
index 00000000..43ddd642
--- /dev/null
+++ b/semantic_router/llms/openai.py
@@ -0,0 +1,53 @@
+import os
+
+import openai
+
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+from semantic_router.utils.logger import logger
+
+
+class OpenAILLM(BaseLLM):
+ client: openai.OpenAI | None
+ temperature: float | None
+ max_tokens: int | None
+
+ def __init__(
+ self,
+ name: str | None = None,
+ openai_api_key: str | None = None,
+ temperature: float = 0.01,
+ max_tokens: int = 200,
+ ):
+ if name is None:
+ name = os.getenv("OPENAI_CHAT_MODEL_NAME", "gpt-3.5-turbo")
+ super().__init__(name=name)
+ api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
+ if api_key is None:
+ raise ValueError("OpenAI API key cannot be 'None'.")
+ try:
+ self.client = openai.OpenAI(api_key=api_key)
+ except Exception as e:
+ raise ValueError(f"OpenAI API client failed to initialize. Error: {e}")
+ self.temperature = temperature
+ self.max_tokens = max_tokens
+
+ def __call__(self, messages: list[Message]) -> str:
+ if self.client is None:
+ raise ValueError("OpenAI client is not initialized.")
+ try:
+ completion = self.client.chat.completions.create(
+ model=self.name,
+ messages=[m.to_openai() for m in messages],
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ output = completion.choices[0].message.content
+
+ if not output:
+ raise Exception("No output generated")
+ return output
+ except Exception as e:
+ logger.error(f"LLM error: {e}")
+ raise Exception(f"LLM error: {e}")
diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py
new file mode 100644
index 00000000..587eeb12
--- /dev/null
+++ b/semantic_router/llms/openrouter.py
@@ -0,0 +1,58 @@
+import os
+
+import openai
+
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+from semantic_router.utils.logger import logger
+
+
+class OpenRouterLLM(BaseLLM):
+ client: openai.OpenAI | None
+ base_url: str | None
+ temperature: float | None
+ max_tokens: int | None
+
+ def __init__(
+ self,
+ name: str | None = None,
+ openrouter_api_key: str | None = None,
+ base_url: str = "https://openrouter.ai/api/v1",
+ temperature: float = 0.01,
+ max_tokens: int = 200,
+ ):
+ if name is None:
+ name = os.getenv(
+ "OPENROUTER_CHAT_MODEL_NAME", "mistralai/mistral-7b-instruct"
+ )
+ super().__init__(name=name)
+ self.base_url = base_url
+ api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY")
+ if api_key is None:
+ raise ValueError("OpenRouter API key cannot be 'None'.")
+ try:
+ self.client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
+ except Exception as e:
+ raise ValueError(f"OpenRouter API client failed to initialize. Error: {e}")
+ self.temperature = temperature
+ self.max_tokens = max_tokens
+
+ def __call__(self, messages: list[Message]) -> str:
+ if self.client is None:
+ raise ValueError("OpenRouter client is not initialized.")
+ try:
+ completion = self.client.chat.completions.create(
+ model=self.name,
+ messages=[m.to_openai() for m in messages],
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ output = completion.choices[0].message.content
+
+ if not output:
+ raise Exception("No output generated")
+ return output
+ except Exception as e:
+ logger.error(f"LLM error: {e}")
+ raise Exception(f"LLM error: {e}")
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 3fc717ef..0d8269f0 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -4,9 +4,9 @@
from pydantic import BaseModel
-from semantic_router.schema import RouteChoice
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message, RouteChoice
from semantic_router.utils import function_call
-from semantic_router.utils.llm import llm
from semantic_router.utils.logger import logger
@@ -43,12 +43,18 @@ class Route(BaseModel):
utterances: list[str]
description: str | None = None
function_schema: dict[str, Any] | None = None
+ llm: BaseLLM | None = None
def __call__(self, query: str) -> RouteChoice:
if self.function_schema:
+ if not self.llm:
+ raise ValueError(
+ "LLM is required for dynamic routes. Please ensure the `llm` "
+ "attribute is set."
+ )
# if a function schema is provided we generate the inputs
extracted_inputs = function_call.extract_function_inputs(
- query=query, function_schema=self.function_schema
+ query=query, llm=self.llm, function_schema=self.function_schema
)
func_call = extracted_inputs
else:
@@ -60,16 +66,16 @@ def to_dict(self):
return self.dict()
@classmethod
- def from_dict(cls, data: dict):
+ def from_dict(cls, data: dict[str, Any]):
return cls(**data)
@classmethod
- def from_dynamic_route(cls, entity: Union[BaseModel, Callable]):
+ def from_dynamic_route(cls, llm: BaseLLM, entity: Union[BaseModel, Callable]):
"""
Generate a dynamic Route object from a function or Pydantic model using LLM
"""
schema = function_call.get_schema(item=entity)
- dynamic_route = cls._generate_dynamic_route(function_schema=schema)
+ dynamic_route = cls._generate_dynamic_route(llm=llm, function_schema=schema)
dynamic_route.function_schema = schema
return dynamic_route
@@ -86,7 +92,7 @@ def _parse_route_config(cls, config: str) -> str:
raise ValueError("No tags found in the output.")
@classmethod
- def _generate_dynamic_route(cls, function_schema: dict[str, Any]):
+ def _generate_dynamic_route(cls, llm: BaseLLM, function_schema: dict[str, Any]):
logger.info("Generating dynamic route...")
prompt = f"""
@@ -114,7 +120,8 @@ def _generate_dynamic_route(cls, function_schema: dict[str, Any]):
{function_schema}
"""
- output = llm(prompt)
+ llm_input = [Message(role="user", content=prompt)]
+ output = llm(llm_input)
if not output:
raise Exception("No output generated for dynamic route")
@@ -123,5 +130,7 @@ def _generate_dynamic_route(cls, function_schema: dict[str, Any]):
logger.info(f"Generated route config:\n{route_config}")
if is_valid(route_config):
- return Route.from_dict(json.loads(route_config))
+ route_config_dict = json.loads(route_config)
+ route_config_dict["llm"] = llm
+ return Route.from_dict(route_config_dict)
raise Exception("No config generated")
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 64480355..5e94c23b 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -6,8 +6,8 @@
from semantic_router.encoders import (
BaseEncoder,
CohereEncoder,
- OpenAIEncoder,
FastEmbedEncoder,
+ OpenAIEncoder,
)
from semantic_router.utils.splitters import semantic_splitter
@@ -52,6 +52,14 @@ class Message(BaseModel):
role: str
content: str
+ def to_openai(self):
+ if self.role.lower() not in ["user", "assistant", "system"]:
+ raise ValueError("Role must be either 'user', 'assistant' or 'system'")
+ return {"role": self.role, "content": self.content}
+
+ def to_cohere(self):
+ return {"role": self.role, "message": self.content}
+
class Conversation(BaseModel):
messages: list[Message]
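The new `Message` helpers normalise a chat turn into each provider's expected shape. A quick sketch of the two outputs:

```python
# Sketch: Message converts to provider-specific dicts.
from semantic_router.schema import Message

msg = Message(role="user", content="Hello!")
print(msg.to_openai())  # {'role': 'user', 'content': 'Hello!'}
print(msg.to_cohere())  # {'role': 'user', 'message': 'Hello!'}
```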
diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py
index 4319a8ec..cedd9b6e 100644
--- a/semantic_router/utils/function_call.py
+++ b/semantic_router/utils/function_call.py
@@ -4,8 +4,8 @@
from pydantic import BaseModel
-from semantic_router.schema import RouteChoice
-from semantic_router.utils.llm import llm
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message, RouteChoice
from semantic_router.utils.logger import logger
@@ -41,7 +41,9 @@ def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:
return schema
-def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict:
+def extract_function_inputs(
+ query: str, llm: BaseLLM, function_schema: dict[str, Any]
+) -> dict:
logger.info("Extracting function input...")
prompt = f"""
@@ -72,8 +74,8 @@ def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict
schema: {function_schema}
Result:
"""
-
- output = llm(prompt)
+ llm_input = [Message(role="user", content=prompt)]
+ output = llm(llm_input)
if not output:
raise Exception("No output generated for extract function input")
@@ -107,7 +109,9 @@ def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) ->
# TODO: Add route layer object to the input, solve circular import issue
-async def route_and_execute(query: str, functions: list[Callable], layer) -> Any:
+async def route_and_execute(
+ query: str, llm: BaseLLM, functions: list[Callable], layer
+) -> Any:
route_choice: RouteChoice = layer(query)
for function in functions:
@@ -116,4 +120,5 @@ async def route_and_execute(query: str, functions: list[Callable], layer) -> Any
return function(**route_choice.function_call)
logger.warning("No function found, calling LLM.")
- return llm(query)
+ llm_input = [Message(role="user", content=query)]
+ return llm(llm_input)
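`route_and_execute` now threads the LLM through as well, falling back to it when no function matches. A sketch of driving it, assuming `rl` and `get_time` exist as in the examples above:

```python
# Sketch: route a query and execute the matched function, falling back to
# the LLM when no function matches. Assumes `rl` and `get_time` exist.
import asyncio

from semantic_router.llms import OpenAILLM
from semantic_router.utils.function_call import route_and_execute

result = asyncio.run(
    route_and_execute(
        query="what is the time in new york city?",
        llm=OpenAILLM(),
        functions=[get_time],
        layer=rl,
    )
)
print(result)
```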
diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
new file mode 100644
index 00000000..df78d8f5
--- /dev/null
+++ b/tests/unit/llms/test_llm_base.py
@@ -0,0 +1,16 @@
+import pytest
+
+from semantic_router.llms import BaseLLM
+
+
+class TestBaseLLM:
+ @pytest.fixture
+ def base_llm(self):
+ return BaseLLM(name="TestLLM")
+
+ def test_base_llm_initialization(self, base_llm):
+ assert base_llm.name == "TestLLM", "Initialization of name failed"
+
+ def test_base_llm_call_method_not_implemented(self, base_llm):
+ with pytest.raises(NotImplementedError):
+ base_llm("test")
diff --git a/tests/unit/llms/test_llm_cohere.py b/tests/unit/llms/test_llm_cohere.py
new file mode 100644
index 00000000..aaf8a7e5
--- /dev/null
+++ b/tests/unit/llms/test_llm_cohere.py
@@ -0,0 +1,52 @@
+import pytest
+
+from semantic_router.llms import CohereLLM
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def cohere_llm(mocker):
+ mocker.patch("cohere.Client")
+ return CohereLLM(cohere_api_key="test_api_key")
+
+
+class TestCohereLLM:
+ def test_initialization_with_api_key(self, cohere_llm):
+ assert cohere_llm.client is not None, "Client should be initialized"
+ assert cohere_llm.name == "command", "Default name not set correctly"
+
+ def test_initialization_without_api_key(self, mocker, monkeypatch):
+ monkeypatch.delenv("COHERE_API_KEY", raising=False)
+ mocker.patch("cohere.Client")
+ with pytest.raises(ValueError):
+ CohereLLM()
+
+ def test_call_method(self, cohere_llm, mocker):
+ mock_llm = mocker.MagicMock()
+ mock_llm.text = "test"
+ cohere_llm.client.chat.return_value = mock_llm
+
+ llm_input = [Message(role="user", content="test")]
+ result = cohere_llm(llm_input)
+ assert isinstance(result, str), "Result should be a str"
+ cohere_llm.client.chat.assert_called_once()
+
+ def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
+ mocker.patch(
+ "cohere.Client", side_effect=Exception("Failed to initialize client")
+ )
+ with pytest.raises(ValueError):
+ CohereLLM(cohere_api_key="test_api_key")
+
+ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):
+ mocker.patch("cohere.Client", return_value=None)
+ llm = CohereLLM(cohere_api_key="test_api_key")
+ with pytest.raises(ValueError):
+ llm("test")
+
+ def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker):
+ mocker.patch.object(
+ cohere_llm.client, "__call__", side_effect=Exception("API call failed")
+ )
+ with pytest.raises(ValueError):
+ cohere_llm("test")
diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py
new file mode 100644
index 00000000..2f1171db
--- /dev/null
+++ b/tests/unit/llms/test_llm_openai.py
@@ -0,0 +1,56 @@
+import pytest
+
+from semantic_router.llms import OpenAILLM
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def openai_llm(mocker):
+ mocker.patch("openai.Client")
+ return OpenAILLM(openai_api_key="test_api_key")
+
+
+class TestOpenAILLM:
+ def test_openai_llm_init_with_api_key(self, openai_llm):
+ assert openai_llm.client is not None, "Client should be initialized"
+ assert openai_llm.name == "gpt-3.5-turbo", "Default name not set correctly"
+
+ def test_openai_llm_init_success(self, mocker):
+ mocker.patch("os.getenv", return_value="fake-api-key")
+ llm = OpenAILLM()
+ assert llm.client is not None
+
+ def test_openai_llm_init_without_api_key(self, mocker):
+ mocker.patch("os.getenv", return_value=None)
+ with pytest.raises(ValueError) as _:
+ OpenAILLM()
+
+ def test_openai_llm_call_uninitialized_client(self, openai_llm):
+ # Set the client to None to simulate an uninitialized client
+ openai_llm.client = None
+ with pytest.raises(ValueError) as e:
+ llm_input = [Message(role="user", content="test")]
+ openai_llm(llm_input)
+ assert "OpenAI client is not initialized." in str(e.value)
+
+ def test_openai_llm_init_exception(self, mocker):
+ mocker.patch("os.getenv", return_value="fake-api-key")
+ mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
+ with pytest.raises(ValueError) as e:
+ OpenAILLM()
+ assert (
+ "OpenAI API client failed to initialize. Error: Initialization error"
+ in str(e.value)
+ )
+
+ def test_openai_llm_call_success(self, openai_llm, mocker):
+ mock_completion = mocker.MagicMock()
+ mock_completion.choices[0].message.content = "test"
+
+ mocker.patch("os.getenv", return_value="fake-api-key")
+ mocker.patch.object(
+ openai_llm.client.chat.completions, "create", return_value=mock_completion
+ )
+ llm_input = [Message(role="user", content="test")]
+ output = openai_llm(llm_input)
+ assert output == "test"
diff --git a/tests/unit/llms/test_llm_openrouter.py b/tests/unit/llms/test_llm_openrouter.py
new file mode 100644
index 00000000..9b1ee150
--- /dev/null
+++ b/tests/unit/llms/test_llm_openrouter.py
@@ -0,0 +1,60 @@
+import pytest
+
+from semantic_router.llms import OpenRouterLLM
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def openrouter_llm(mocker):
+ mocker.patch("openai.Client")
+ return OpenRouterLLM(openrouter_api_key="test_api_key")
+
+
+class TestOpenRouterLLM:
+ def test_openrouter_llm_init_with_api_key(self, openrouter_llm):
+ assert openrouter_llm.client is not None, "Client should be initialized"
+ assert (
+ openrouter_llm.name == "mistralai/mistral-7b-instruct"
+ ), "Default name not set correctly"
+
+ def test_openrouter_llm_init_success(self, mocker):
+ mocker.patch("os.getenv", return_value="fake-api-key")
+ llm = OpenRouterLLM()
+ assert llm.client is not None
+
+ def test_openrouter_llm_init_without_api_key(self, mocker):
+ mocker.patch("os.getenv", return_value=None)
+ with pytest.raises(ValueError) as _:
+ OpenRouterLLM()
+
+ def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm):
+ # Set the client to None to simulate an uninitialized client
+ openrouter_llm.client = None
+ with pytest.raises(ValueError) as e:
+ llm_input = [Message(role="user", content="test")]
+ openrouter_llm(llm_input)
+ assert "OpenRouter client is not initialized." in str(e.value)
+
+ def test_openrouter_llm_init_exception(self, mocker):
+ mocker.patch("os.getenv", return_value="fake-api-key")
+ mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
+ with pytest.raises(ValueError) as e:
+ OpenRouterLLM()
+ assert (
+ "OpenRouter API client failed to initialize. Error: Initialization error"
+ in str(e.value)
+ )
+
+ def test_openrouter_llm_call_success(self, openrouter_llm, mocker):
+ mock_completion = mocker.MagicMock()
+ mock_completion.choices[0].message.content = "test"
+
+ mocker.patch("os.getenv", return_value="fake-api-key")
+ mocker.patch.object(
+ openrouter_llm.client.chat.completions,
+ "create",
+ return_value=mock_completion,
+ )
+ llm_input = [Message(role="user", content="test")]
+ output = openrouter_llm(llm_input)
+ assert output == "test"
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 32754997..495d1bdc 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -92,6 +92,14 @@ def routes():
]
+@pytest.fixture
+def dynamic_routes():
+ return [
+ Route(name="Route 1", utterances=["Hello", "Hi"], function_schema="test"),
+ Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]),
+ ]
+
+
class TestRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
@@ -106,7 +114,12 @@ def test_initialization(self, openai_encoder, routes):
def test_initialization_different_encoders(self, cohere_encoder, openai_encoder):
route_layer_cohere = RouteLayer(encoder=cohere_encoder)
assert route_layer_cohere.score_threshold == 0.3
+ route_layer_openai = RouteLayer(encoder=openai_encoder)
+ assert route_layer_openai.score_threshold == 0.82
+ def test_initialization_dynamic_route(self, cohere_encoder, openai_encoder):
+ route_layer_cohere = RouteLayer(encoder=cohere_encoder)
+ assert route_layer_cohere.score_threshold == 0.3
route_layer_openai = RouteLayer(encoder=openai_encoder)
assert route_layer_openai.score_threshold == 0.82
diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py
index 09a5d235..33a9ac13 100644
--- a/tests/unit/test_route.py
+++ b/tests/unit/test_route.py
@@ -1,6 +1,8 @@
-from unittest.mock import Mock, patch # , AsyncMock
+from unittest.mock import patch # , AsyncMock
-# import pytest
+import pytest
+
+from semantic_router.llms import BaseLLM
from semantic_router.route import Route, is_valid
@@ -41,11 +43,9 @@ def test_is_valid_with_invalid_json():
mock_logger.error.assert_called_once()
-class TestRoute:
- @patch("semantic_router.route.llm", new_callable=Mock)
- def test_generate_dynamic_route(self, mock_llm):
- print(f"mock_llm: {mock_llm}")
- mock_llm.return_value = """
+class MockLLM(BaseLLM):
+ def __call__(self, prompt):
+ llm_output = """
{
"name": "test_function",
@@ -58,8 +58,28 @@ def test_generate_dynamic_route(self, mock_llm):
}
"""
+ return llm_output
+
+
+class TestRoute:
+ def test_value_error_in_route_call(self):
function_schema = {"name": "test_function", "type": "function"}
- route = Route._generate_dynamic_route(function_schema)
+
+ route = Route(
+ name="test_function",
+ utterances=["utterance1", "utterance2"],
+ function_schema=function_schema,
+ )
+
+ with pytest.raises(ValueError):
+ route("test_query")
+
+ def test_generate_dynamic_route(self):
+ mock_llm = MockLLM(name="test")
+ function_schema = {"name": "test_function", "type": "function"}
+ route = Route._generate_dynamic_route(
+ llm=mock_llm, function_schema=function_schema
+ )
assert route.name == "test_function"
assert route.utterances == [
"example_utterance_1",
@@ -105,6 +125,7 @@ def test_to_dict(self):
"utterances": ["utterance"],
"description": None,
"function_schema": None,
+ "llm": None,
}
assert route.to_dict() == expected_dict
@@ -114,28 +135,15 @@ def test_from_dict(self):
assert route.name == "test"
assert route.utterances == ["utterance"]
- @patch("semantic_router.route.llm", new_callable=Mock)
- def test_from_dynamic_route(self, mock_llm):
+ def test_from_dynamic_route(self):
# Mock the llm function
- mock_llm.return_value = """
-
- {
- "name": "test_function",
- "utterances": [
- "example_utterance_1",
- "example_utterance_2",
- "example_utterance_3",
- "example_utterance_4",
- "example_utterance_5"]
- }
-
- """
+ mock_llm = MockLLM(name="test")
def test_function(input: str):
"""Test function docstring"""
pass
- dynamic_route = Route.from_dynamic_route(test_function)
+ dynamic_route = Route.from_dynamic_route(llm=mock_llm, entity=test_function)
assert dynamic_route.name == "test_function"
assert dynamic_route.utterances == [
diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py
index 97b5028e..a41d5fa7 100644
--- a/tests/unit/test_schema.py
+++ b/tests/unit/test_schema.py
@@ -1,9 +1,11 @@
import pytest
+from pydantic import ValidationError
from semantic_router.schema import (
CohereEncoder,
Encoder,
EncoderType,
+ Message,
OpenAIEncoder,
)
@@ -38,3 +40,27 @@ def test_encoder_call_method(self, mocker):
encoder = Encoder(type="openai", name="test-engine")
result = encoder(["test"])
assert result == [0.1, 0.2, 0.3]
+
+
+class TestMessageDataclass:
+ def test_message_creation(self):
+ message = Message(role="user", content="Hello!")
+ assert message.role == "user"
+ assert message.content == "Hello!"
+
+ with pytest.raises(ValidationError):
+ Message(user_role="invalid_role", message="Hello!")
+
+ def test_message_to_openai(self):
+ message = Message(role="user", content="Hello!")
+ openai_format = message.to_openai()
+ assert openai_format == {"role": "user", "content": "Hello!"}
+
+ message = Message(role="invalid_role", content="Hello!")
+ with pytest.raises(ValueError):
+ message.to_openai()
+
+ def test_message_to_cohere(self):
+ message = Message(role="user", content="Hello!")
+ cohere_format = message.to_cohere()
+ assert cohere_format == {"role": "user", "message": "Hello!"}