From 3b0de2d59299b9c83ff7b32897016b3ea743f8dd Mon Sep 17 00:00:00 2001 From: zahid-syed Date: Mon, 11 Mar 2024 13:22:38 -0400 Subject: [PATCH] fixing pytest issues for optional dependency issues --- semantic_router/encoders/mistral.py | 18 +++++++++--------- semantic_router/llms/mistral.py | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index ead08545..3c821327 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -13,8 +13,8 @@ class MistralEncoder(BaseEncoder): """Class to encode text using MistralAI""" _client: Any = PrivateAttr() - embedding_response: Any = PrivateAttr() - mistral_exception: Any = PrivateAttr() + _embedding_response: Any = PrivateAttr() + _mistral_exception: Any = PrivateAttr() type: str = "mistral" def __init__( @@ -51,14 +51,14 @@ def _initialize_client(self, api_key): ) try: - self.client = MistralClient(api_key=api_key) - self.embedding_response = EmbeddingResponse - self.mistral_exception = MistralException + self._client = MistralClient(api_key=api_key) + self._embedding_response = EmbeddingResponse + self._mistral_exception = MistralException except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e def __call__(self, docs: List[str]) -> List[List[float]]: - if self.client is None: + if self._client is None: raise ValueError("Mistral client not initialized") embeds = None error_message = "" @@ -66,10 +66,10 @@ def __call__(self, docs: List[str]) -> List[List[float]]: # Exponential backoff for _ in range(3): try: - embeds = self.client.embeddings(model=self.name, input=docs) + embeds = self._client.embeddings(model=self.name, input=docs) if embeds.data: break - except self.mistral_exception as e: + except self._mistral_exception as e: sleep(2**_) error_message = str(e) except Exception as e: @@ -77,7 +77,7 @@ def __call__(self, docs: List[str]) -> List[List[float]]: if ( not embeds - or not isinstance(embeds, self.embedding_response) + or not isinstance(embeds, self._embedding_response) or not embeds.data ): raise ValueError(f"No embeddings returned from MistralAI: {error_message}") diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index c873bd25..adecd22c 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -11,7 +11,7 @@ class MistralAILLM(BaseLLM): - client: Any = PrivateAttr() + _client: Any = PrivateAttr() temperature: Optional[float] max_tokens: Optional[int] @@ -42,17 +42,17 @@ def _initialize_client(self, api_key): "`pip install 'semantic-router[mistralai]'`" ) try: - self.client = MistralClient(api_key=api_key) + self._client = MistralClient(api_key=api_key) except Exception as e: raise ValueError( f"MistralAI API client failed to initialize. Error: {e}" ) from e def __call__(self, messages: List[Message]) -> str: - if self.client is None: + if self._client is None: raise ValueError("MistralAI client is not initialized.") try: - completion = self.client.chat( + completion = self._client.chat( model=self.name, messages=[m.to_mistral() for m in messages], temperature=self.temperature,