From 80ba0024f7f0ae1601cd577605e32ce7054b805d Mon Sep 17 00:00:00 2001 From: zahid-syed Date: Mon, 11 Mar 2024 11:56:46 -0400 Subject: [PATCH] started fix on llama cpp optional dependency --- semantic_router/encoders/mistral.py | 24 +++++++++++------------- semantic_router/llms/llamacpp.py | 23 +++++++++++++++++------ semantic_router/llms/mistral.py | 5 ++--- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index 5beabb71..0ab56824 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -4,13 +4,11 @@ from typing import List, Optional, Any - from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault from pydantic.v1 import PrivateAttr - class MistralEncoder(BaseEncoder): """Class to encode text using MistralAI""" @@ -32,26 +30,25 @@ def __init__( if api_key is None: raise ValueError("Mistral API key not provided") self._client = self._initialize_client(mistralai_api_key) - + def _initialize_client(self, api_key): try: from mistralai.client import MistralClient except ImportError: - raise ImportError( + raise ImportError( "Please install MistralAI to use MistralEncoder. " "You can install it with: " "`pip install 'semantic-router[mistralai]'`" ) try: - from mistralai.exceptions import MistralException - from mistralai.models.embeddings import EmbeddingResponse + from mistralai.exceptions import MistralException + from mistralai.models.embeddings import EmbeddingResponse except ImportError: - raise ImportError( + raise ImportError( "Please install MistralAI to use MistralEncoder. 
" "You can install it with: " "`pip install 'semantic-router[mistralai]'`" ) - try: self.client = MistralClient(api_key=api_key) @@ -60,10 +57,7 @@ def _initialize_client(self, api_key): except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e - def __call__(self, docs: List[str]) -> List[List[float]]: - - if self.client is None: raise ValueError("Mistral client not initialized") embeds = None @@ -81,7 +75,11 @@ def __call__(self, docs: List[str]) -> List[List[float]]: except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e - if not embeds or not isinstance(embeds, self.embedding_response) or not embeds.data: + if ( + not embeds + or not isinstance(embeds, self.embedding_response) + or not embeds.data + ): raise ValueError(f"No embeddings returned from MistralAI: {error_message}") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] - return embeddings \ No newline at end of file + return embeddings diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index 2586d2e4..527d97c1 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any, Optional -from llama_cpp import Llama, LlamaGrammar +# from llama_cpp import Llama, LlamaGrammar from semantic_router.llms.base import BaseLLM from semantic_router.schema import Message @@ -10,18 +10,18 @@ class LlamaCppLLM(BaseLLM): - llm: Llama + llm: Any temperature: float max_tokens: Optional[int] = 200 - grammar: Optional[LlamaGrammar] = None + grammar: Optional[Any] = None def __init__( self, - llm: Llama, + llm: Any, name: str = "llama.cpp", temperature: float = 0.2, max_tokens: Optional[int] = 200, - grammar: Optional[LlamaGrammar] = None, + grammar: Optional[Any] = None, ): super().__init__( name=name, @@ -30,6 +30,17 @@ def __init__( max_tokens=max_tokens, grammar=grammar, ) + + try: + from llama_cpp import Llama, 
LlamaGrammar + except ImportError: + raise ImportError( + "Please install LlamaCPP to use Llama CPP llm. " + "You can install it with: " + "`pip install 'semantic-router[llama-cpp-python]'`" + ) + llm = Llama + grammar = Optional[LlamaGrammar] self.llm = llm self.temperature = temperature self.max_tokens = max_tokens @@ -62,7 +73,7 @@ def _grammar(self): grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf") assert grammar_path.exists(), f"{grammar_path}\ndoes not exist" try: - self.grammar = LlamaGrammar.from_file(grammar_path) + self.grammar = self.grammar.from_file(grammar_path) yield finally: self.grammar = None diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index 42cbe46d..c873bd25 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -2,7 +2,6 @@ from typing import List, Optional, Any - from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault @@ -37,7 +36,7 @@ def _initialize_client(self, api_key): try: from mistralai.client import MistralClient except ImportError: - raise ImportError( + raise ImportError( "Please install MistralAI to use MistralEncoder. " "You can install it with: " "`pip install 'semantic-router[mistralai]'`" @@ -48,7 +47,7 @@ def _initialize_client(self, api_key): raise ValueError( f"MistralAI API client failed to initialize. Error: {e}" ) from e - + def __call__(self, messages: List[Message]) -> str: if self.client is None: raise ValueError("MistralAI client is not initialized.")