Skip to content

Commit

Permalink
Started a fix for the llama.cpp optional-dependency handling
Browse files Browse the repository at this point in the history
  • Loading branch information
zahid-syed committed Mar 11, 2024
1 parent 6b901f8 commit 80ba002
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
24 changes: 11 additions & 13 deletions semantic_router/encoders/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from typing import List, Optional, Any



from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr



class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI"""

Expand All @@ -32,26 +30,25 @@ def __init__(
if api_key is None:
raise ValueError("Mistral API key not provided")
self._client = self._initialize_client(mistralai_api_key)

def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try:
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
except ImportError:
raise ImportError(
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)


try:
self.client = MistralClient(api_key=api_key)
Expand All @@ -60,10 +57,7 @@ def _initialize_client(self, api_key):
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e


def __call__(self, docs: List[str]) -> List[List[float]]:


if self.client is None:
raise ValueError("Mistral client not initialized")
embeds = None
Expand All @@ -81,7 +75,11 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

if not embeds or not isinstance(embeds, self.embedding_response) or not embeds.data:
if (
not embeds
or not isinstance(embeds, self.embedding_response)
or not embeds.data
):
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
return embeddings
23 changes: 17 additions & 6 deletions semantic_router/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
from pathlib import Path
from typing import Any, Optional

from llama_cpp import Llama, LlamaGrammar
# from llama_cpp import Llama, LlamaGrammar

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class LlamaCppLLM(BaseLLM):
llm: Llama
llm: Any
temperature: float
max_tokens: Optional[int] = 200
grammar: Optional[LlamaGrammar] = None
grammar: Optional[Any] = None

def __init__(
self,
llm: Llama,
llm: Any,
name: str = "llama.cpp",
temperature: float = 0.2,
max_tokens: Optional[int] = 200,
grammar: Optional[LlamaGrammar] = None,
grammar: Optional[Any] = None,
):
super().__init__(
name=name,
Expand All @@ -30,6 +30,17 @@ def __init__(
max_tokens=max_tokens,
grammar=grammar,
)

try:
from llama_cpp import Llama, LlamaGrammar
except ImportError:
raise ImportError(
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[llama-cpp-python]'`"
)
llm = Llama
grammar = Optional[LlamaGrammar]
self.llm = llm
self.temperature = temperature
self.max_tokens = max_tokens
Expand Down Expand Up @@ -62,7 +73,7 @@ def _grammar(self):
grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
try:
self.grammar = LlamaGrammar.from_file(grammar_path)
self.grammar = self.grammar.from_file(grammar_path)
yield
finally:
self.grammar = None
Expand Down
5 changes: 2 additions & 3 deletions semantic_router/llms/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import List, Optional, Any



from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
Expand Down Expand Up @@ -37,7 +36,7 @@ def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
Expand All @@ -48,7 +47,7 @@ def _initialize_client(self, api_key):
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
raise ValueError("MistralAI client is not initialized.")
Expand Down

0 comments on commit 80ba002

Please sign in to comment.