Skip to content

Commit

Permalink
Merge branch 'main' into voting-support
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam authored Mar 15, 2024
2 parents fcb2d1a + 1394adf commit 3245c54
Show file tree
Hide file tree
Showing 12 changed files with 293 additions and 158 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ politics = Route(
utterances=[
"isn't politics the best thing ever",
"why don't you tell me about your political opinions",
"don't you just love the president" "don't you just hate the president",
"don't you just love the president",
"they're going to destroy this country!",
"they will save the country!",
],
Expand Down Expand Up @@ -147,4 +147,4 @@ Daniel Avila, [Semantic Router: Enhancing Control in LLM Conversations](https://

Yogendra Sisodia, [Stop Chat-GPT From Going Rogue In Production With Semantic Router](https://medium.com/@scholarly360/stop-chat-gpt-from-going-rogue-in-production-with-semantic-router-937a4768ae19), Medium

Aniket Hingane, [LLM Apps: Why you Must Know Semantic Router in 2024: Part 1](https://medium.com/@learn-simplified/llm-apps-why-you-must-know-semantic-router-in-2024-part-1-bfbda81374c5), Medium
Aniket Hingane, [LLM Apps: Why you Must Know Semantic Router in 2024: Part 1](https://medium.com/@learn-simplified/llm-apps-why-you-must-know-semantic-router-in-2024-part-1-bfbda81374c5), Medium
258 changes: 146 additions & 112 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
mistralai= "^0.0.12"
mistralai= {version = "^0.0.12", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
pinecone-text = {version = "^0.7.1", optional = true}
fastembed = {version = "^0.1.3", optional = true, python = "<3.12"}
fastembed = {version = "^0.2.4", optional = true, python = "<3.12"}
torch = {version = "^2.1.0", optional = true}
transformers = {version = "^4.36.2", optional = true}
llama-cpp-python = {version = "^0.2.28", optional = true}
Expand All @@ -44,6 +44,7 @@ local = ["torch", "transformers", "llama-cpp-python"]
pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
mistralai = ["mistralai"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand All @@ -67,4 +68,4 @@ build-backend = "poetry.core.masonry.api"
line-length = 88

[tool.mypy]
ignore_missing_imports = true
ignore_missing_imports = true
4 changes: 2 additions & 2 deletions semantic_router/encoders/fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(

def _initialize_client(self):
try:
from fastembed.embedding import FlagEmbedding as Embedding
from fastembed import TextEmbedding
except ImportError:
raise ImportError(
"Please install fastembed to use FastEmbedEncoder. "
Expand All @@ -39,7 +39,7 @@ def _initialize_client(self):

embedding_args = {k: v for k, v in embedding_args.items() if v is not None}

embedding = Embedding(**embedding_args)
embedding = TextEmbedding(**embedding_args)
return embedding

def __call__(self, docs: List[str]) -> List[List[float]]:
Expand Down
41 changes: 30 additions & 11 deletions semantic_router/encoders/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@

import os
from time import sleep
from typing import List, Optional
from typing import List, Optional, Any

from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr


class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI"""

client: Optional[MistralClient]
_client: Any = PrivateAttr()
_mistralai: Any = PrivateAttr()
type: str = "mistral"

def __init__(
Expand All @@ -27,33 +26,53 @@ def __init__(
if name is None:
name = EncoderDefault.MISTRAL.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
self._client, self._mistralai = self._initialize_client(mistralai_api_key)

def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)

api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
try:
self.client = MistralClient(api_key=api_key)
client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
return client, mistralai

def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None:
if self._client is None:
raise ValueError("Mistral client not initialized")
embeds = None
error_message = ""

# Exponential backoff
for _ in range(3):
try:
embeds = self.client.embeddings(model=self.name, input=docs)
embeds = self._client.embeddings(model=self.name, input=docs)
if embeds.data:
break
except MistralException as e:
except self._mistralai.exceptions.MistralException as e:
sleep(2**_)
error_message = str(e)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
if (
not embeds
or not isinstance(
embeds, self._mistralai.models.embeddings.EmbeddingResponse
)
or not embeds.data
):
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
6 changes: 5 additions & 1 deletion semantic_router/encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class OpenAIEncoder(BaseEncoder):
def __init__(
self,
name: Optional[str] = None,
openai_base_url: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_org_id: Optional[str] = None,
score_threshold: float = 0.82,
Expand All @@ -29,11 +30,14 @@ def __init__(
name = EncoderDefault.OPENAI.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
if api_key is None:
raise ValueError("OpenAI API key cannot be 'None'.")
try:
self.client = openai.Client(api_key=api_key, organization=openai_org_id)
self.client = openai.Client(
base_url=base_url, api_key=api_key, organization=openai_org_id
)
except Exception as e:
raise ValueError(
f"OpenAI API client failed to initialize. Error: {e}"
Expand Down
2 changes: 2 additions & 0 deletions semantic_router/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.llms.mistral import MistralAILLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
Expand All @@ -8,6 +9,7 @@
__all__ = [
"BaseLLM",
"OpenAILLM",
"LlamaCppLLM",
"OpenRouterLLM",
"CohereLLM",
"AzureOpenAILLM",
Expand Down
27 changes: 20 additions & 7 deletions semantic_router/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,27 @@
from pathlib import Path
from typing import Any, Optional

from llama_cpp import Llama, LlamaGrammar

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class LlamaCppLLM(BaseLLM):
llm: Llama
llm: Any
temperature: float
max_tokens: Optional[int] = 200
grammar: Optional[LlamaGrammar] = None
grammar: Optional[Any] = None
_llama_cpp: Any = PrivateAttr()

def __init__(
self,
llm: Llama,
llm: Any,
name: str = "llama.cpp",
temperature: float = 0.2,
max_tokens: Optional[int] = 200,
grammar: Optional[LlamaGrammar] = None,
grammar: Optional[Any] = None,
):
super().__init__(
name=name,
Expand All @@ -30,6 +31,18 @@ def __init__(
max_tokens=max_tokens,
grammar=grammar,
)

try:
import llama_cpp
except ImportError:
raise ImportError(
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[local]'`"
)
self._llama_cpp = llama_cpp
llm = self._llama_cpp.Llama
grammar = self._llama_cpp.LlamaGrammar
self.llm = llm
self.temperature = temperature
self.max_tokens = max_tokens
Expand Down Expand Up @@ -62,7 +75,7 @@ def _grammar(self):
grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
try:
self.grammar = LlamaGrammar.from_file(grammar_path)
self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
yield
finally:
self.grammar = None
Expand Down
42 changes: 32 additions & 10 deletions semantic_router/llms/mistral.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import os
from typing import List, Optional
from typing import List, Optional, Any

from mistralai.client import MistralClient

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class MistralAILLM(BaseLLM):
client: Optional[MistralClient]
_client: Any = PrivateAttr()
temperature: Optional[float]
max_tokens: Optional[int]
_mistralai: Any = PrivateAttr()

def __init__(
self,
Expand All @@ -24,25 +26,45 @@ def __init__(
if name is None:
name = EncoderDefault.MISTRAL.value["language_model"]
super().__init__(name=name)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
self.temperature = temperature
self.max_tokens = max_tokens

def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralAI LLM. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
try:
self.client = MistralClient(api_key=api_key)
client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e
self.temperature = temperature
self.max_tokens = max_tokens
return client, mistralai

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
if self._client is None:
raise ValueError("MistralAI client is not initialized.")

chat_messages = [
self._mistralai.models.chat_completion.ChatMessage(
role=m.role, content=m.content
)
for m in messages
]
try:
completion = self.client.chat(
completion = self._client.chat(
model=self.name,
messages=[m.to_mistral() for m in messages],
messages=chat_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
Expand Down
Loading

0 comments on commit 3245c54

Please sign in to comment.