Skip to content

Commit

Permalink
Fix issues relating to optional dependencies and the Mistral integration
Browse files Browse the repository at this point in the history
  • Loading branch information
zahid-syed committed Mar 12, 2024
1 parent 3b0de2d commit 354307e
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 121 deletions.
190 changes: 95 additions & 95 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
mistralai= {version = "^0.0.12", optional = true}
mistralai = {version = "^0.0.12", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
Expand Down
20 changes: 9 additions & 11 deletions semantic_router/encoders/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,15 @@ def __init__(
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
self._client = self._initialize_client(mistralai_api_key)
(
self._client,
self._embedding_response,
self._mistral_exception,
) = self._initialize_client(mistralai_api_key)

def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try:
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
except ImportError:
Expand All @@ -51,11 +48,12 @@ def _initialize_client(self, api_key):
)

try:
self._client = MistralClient(api_key=api_key)
self._embedding_response = EmbeddingResponse
self._mistral_exception = MistralException
client = MistralClient(api_key=api_key)
embedding_response = EmbeddingResponse
mistral_exception = MistralException
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
return client, embedding_response, mistral_exception

def __call__(self, docs: List[str]) -> List[List[float]]:
if self._client is None:
Expand Down
12 changes: 8 additions & 4 deletions semantic_router/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from semantic_router.schema import Message
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class LlamaCppLLM(BaseLLM):
llm: Any
temperature: float
max_tokens: Optional[int] = 200
grammar: Optional[Any] = None
_llama_cpp: Any = PrivateAttr()

def __init__(
self,
Expand All @@ -32,15 +35,16 @@ def __init__(
)

try:
from llama_cpp import Llama, LlamaGrammar
import llama_cpp
except ImportError:
raise ImportError(
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[llama-cpp-python]'`"
)
llm = Llama
grammar = Optional[LlamaGrammar]
self._llama_cpp = llama_cpp
llm = self._llama_cpp.Llama
grammar = Optional[self._llama_cpp.LlamaGrammar]
self.llm = llm
self.temperature = temperature
self.max_tokens = max_tokens
Expand Down Expand Up @@ -73,7 +77,7 @@ def _grammar(self):
grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
try:
self.grammar = self.grammar.from_file(grammar_path)
self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
yield
finally:
self.grammar = None
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/encoders/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def mistralai_encoder(mocker):
class TestMistralEncoder:
def test_mistralai_encoder_init_success(self, mocker):
encoder = MistralEncoder(mistralai_api_key="test_api_key")
assert encoder.client is not None
assert encoder._client is not None

def test_mistralai_encoder_init_no_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
Expand All @@ -23,7 +23,7 @@ def test_mistralai_encoder_init_no_api_key(self, mocker):

def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder):
# Set the client to None to simulate an uninitialized client
mistralai_encoder.client = None
mistralai_encoder._client = None
with pytest.raises(ValueError) as e:
mistralai_encoder(["test document"])
assert "Mistral client not initialized" in str(e.value)
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker):

responses = [MistralException("mistralai error"), mock_response]
mocker.patch.object(
mistralai_encoder.client, "embeddings", side_effect=responses
mistralai_encoder._client, "embeddings", side_effect=responses
)
embeddings = mistralai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
Expand All @@ -69,7 +69,7 @@ def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
mistralai_encoder.client,
mistralai_encoder._client,
"embeddings",
side_effect=MistralException("Test error"),
)
Expand All @@ -83,7 +83,7 @@ def test_mistralai_encoder_call_failure_non_mistralai_error(
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
mistralai_encoder.client,
mistralai_encoder._client,
"embeddings",
side_effect=Exception("Non-MistralException"),
)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker

responses = [MistralException("mistralai error"), mock_response]
mocker.patch.object(
mistralai_encoder.client, "embeddings", side_effect=responses
mistralai_encoder._client, "embeddings", side_effect=responses
)
embeddings = mistralai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
8 changes: 4 additions & 4 deletions tests/unit/llms/test_llm_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ def mistralai_llm(mocker):

class TestMistralAILLM:
def test_mistralai_llm_init_with_api_key(self, mistralai_llm):
assert mistralai_llm.client is not None, "Client should be initialized"
assert mistralai_llm._client is not None, "Client should be initialized"
assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly"

def test_mistralai_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = MistralAILLM()
assert llm.client is not None
assert llm._client is not None

def test_mistralai_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
Expand All @@ -27,7 +27,7 @@ def test_mistralai_llm_init_without_api_key(self, mocker):

def test_mistralai_llm_call_uninitialized_client(self, mistralai_llm):
# Set the client to None to simulate an uninitialized client
mistralai_llm.client = None
mistralai_llm._client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
mistralai_llm(llm_input)
Expand All @@ -48,7 +48,7 @@ def test_mistralai_llm_call_success(self, mistralai_llm, mocker):

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
mistralai_llm.client,
mistralai_llm._client,
"chat",
return_value=mock_completion,
)
Expand Down

0 comments on commit 354307e

Please sign in to comment.