
Merge pull request #389 from aurelio-labs/ashraq/openai-max-retry
feat: add max_retries to OpenAI and Azure encoders
jamescalam committed Aug 19, 2024
2 parents 153bceb + e1ef5a0 commit 592fd3b
Showing 6 changed files with 351 additions and 76 deletions.
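
The change is easiest to see from the caller's side: both encoders now accept a max_retries argument (default 3) that bounds the exponential-backoff loop in their call methods. A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment and that the embedding model named here is available to your account:

from semantic_router.encoders.openai import OpenAIEncoder

# max_retries=0 disables retries entirely (the first OpenAIError is re-raised);
# any positive value allows that many retries with 1s, 2s, 4s, ... backoff.
encoder = OpenAIEncoder(name="text-embedding-3-small", max_retries=5)
embeddings = encoder(["semantic routing is fun"])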
44 changes: 29 additions & 15 deletions semantic_router/encoders/openai.py
@@ -42,6 +42,7 @@ class OpenAIEncoder(BaseEncoder):
     token_limit: int = 8192  # default value, should be replaced by config
     _token_encoder: Any = PrivateAttr()
     type: str = "openai"
+    max_retries: int = 3

     def __init__(
         self,
@@ -51,6 +52,7 @@ def __init__(
         openai_org_id: Optional[str] = None,
         score_threshold: Optional[float] = None,
         dimensions: Union[int, NotGiven] = NotGiven(),
+        max_retries: int = 3,
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
@@ -72,6 +74,8 @@ def __init__(
         openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
         if api_key is None:
             raise ValueError("OpenAI API key cannot be 'None'.")
+        if max_retries is not None:
+            self.max_retries = max_retries
         try:
             self.client = openai.Client(
                 base_url=base_url, api_key=api_key, organization=openai_org_id
@@ -102,14 +106,13 @@ def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
         if self.client is None:
             raise ValueError("OpenAI client is not initialized.")
         embeds = None
-        error_message = ""

         if truncate:
             # check if any document exceeds token limit and truncate if so
             docs = [self._truncate(doc) for doc in docs]

         # Exponential backoff
-        for j in range(1, 7):
+        for j in range(self.max_retries + 1):
             try:
                 embeds = self.client.embeddings.create(
                     input=docs,
@@ -119,20 +122,26 @@ def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
                 if embeds.data:
                     break
             except OpenAIError as e:
-                sleep(2**j)
-                error_message = str(e)
-                logger.warning(f"Retrying in {2**j} seconds...")
+                logger.error("Exception occurred", exc_info=True)
+                if self.max_retries != 0 and j < self.max_retries:
+                    sleep(2**j)
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
+                else:
+                    raise

             except Exception as e:
-                logger.error(f"OpenAI API call failed. Error: {error_message}")
-                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
+                logger.error(f"OpenAI API call failed. Error: {e}")
+                raise ValueError(f"OpenAI API call failed. Error: {str(e)}") from e

         if (
             not embeds
             or not isinstance(embeds, CreateEmbeddingResponse)
             or not embeds.data
         ):
-            raise ValueError(f"No embeddings returned. Error: {error_message}")
+            logger.info(f"Returned embeddings: {embeds}")
+            raise ValueError("No embeddings returned.")
@@ -154,14 +163,13 @@ async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
         if self.async_client is None:
             raise ValueError("OpenAI async client is not initialized.")
         embeds = None
-        error_message = ""

         if truncate:
             # check if any document exceeds token limit and truncate if so
             docs = [self._truncate(doc) for doc in docs]

         # Exponential backoff
-        for j in range(1, 7):
+        for j in range(self.max_retries + 1):
             try:
                 embeds = await self.async_client.embeddings.create(
                     input=docs,
@@ -171,11 +179,17 @@ async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
                 if embeds.data:
                     break
             except OpenAIError as e:
-                await asleep(2**j)
-                error_message = str(e)
-                logger.warning(f"Retrying in {2**j} seconds...")
+                logger.error("Exception occurred", exc_info=True)
+                if self.max_retries != 0 and j < self.max_retries:
+                    await asleep(2**j)
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
+                else:
+                    raise
+
             except Exception as e:
-                logger.error(f"OpenAI API call failed. Error: {error_message}")
+                logger.error(f"OpenAI API call failed. Error: {e}")
                 raise ValueError(f"OpenAI API call failed. Error: {e}") from e

         if (
@@ -184,7 +198,7 @@ async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
             or not embeds.data
         ):
             logger.info(f"Returned embeddings: {embeds}")
-            raise ValueError(f"No embeddings returned. Error: {error_message}")
+            raise ValueError("No embeddings returned.")

         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
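
For readers skimming the diff: the new loop makes at most max_retries + 1 attempts and re-raises the OpenAIError once they are exhausted. The standalone sketch below mirrors that control flow outside the encoder; the function name and its arguments are illustrative only, not library code:

import time


def call_with_retries(make_request, max_retries: int = 3):
    """Illustrative mirror of the encoder's retry loop, not library code."""
    for j in range(max_retries + 1):
        try:
            return make_request()
        except Exception as e:
            if max_retries != 0 and j < max_retries:
                wait = 2**j  # exponential backoff: 1s, 2s, 4s, ...
                print(f"Retrying in {wait} seconds due to error: {e}")
                time.sleep(wait)
            else:
                raise  # retries disabled (max_retries=0) or exhausted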
49 changes: 27 additions & 22 deletions semantic_router/encoders/zure.py
@@ -23,6 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder):
     azure_endpoint: Optional[str] = None
     api_version: Optional[str] = None
     model: Optional[str] = None
+    max_retries: int = 3

     def __init__(
         self,
@@ -33,6 +34,7 @@ def __init__(
         model: Optional[str] = None,  # TODO we should change to `name` JB
         score_threshold: float = 0.82,
         dimensions: Union[int, NotGiven] = NotGiven(),
+        max_retries: int = 3,
     ):
         name = deployment_name
         if name is None:
@@ -49,6 +51,8 @@ def __init__(
         self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
         if self.api_key is None:
             raise ValueError("No Azure OpenAI API key provided.")
+        if max_retries is not None:
+            self.max_retries = max_retries
         if self.deployment_name is None:
             self.deployment_name = EncoderDefault.AZURE.value["deployment_name"]
         # deployment_name may still be None, but it is optional in the API
@@ -97,10 +101,9 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
         if self.client is None:
             raise ValueError("Azure OpenAI client is not initialized.")
         embeds = None
-        error_message = ""

         # Exponential backoff
-        for j in range(3):
+        for j in range(self.max_retries + 1):
             try:
                 embeds = self.client.embeddings.create(
                     input=docs,
@@ -110,23 +113,24 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
                 if embeds.data:
                     break
             except OpenAIError as e:
-                # print full traceback
-                import traceback
-
-                traceback.print_exc()
-                sleep(2**j)
-                error_message = str(e)
-                logger.warning(f"Retrying in {2**j} seconds...")
+                logger.error("Exception occurred", exc_info=True)
+                if self.max_retries != 0 and j < self.max_retries:
+                    sleep(2**j)
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
+                else:
+                    raise
             except Exception as e:
-                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
+                logger.error(f"Azure OpenAI API call failed. Error: {e}")
                 raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e

         if (
             not embeds
             or not isinstance(embeds, CreateEmbeddingResponse)
             or not embeds.data
         ):
-            raise ValueError(f"No embeddings returned. Error: {error_message}")
+            raise ValueError("No embeddings returned.")

         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
@@ -135,10 +139,9 @@ async def acall(self, docs: List[str]) -> List[List[float]]:
         if self.async_client is None:
             raise ValueError("Azure OpenAI async client is not initialized.")
         embeds = None
-        error_message = ""

         # Exponential backoff
-        for j in range(3):
+        for j in range(self.max_retries + 1):
             try:
                 embeds = await self.async_client.embeddings.create(
                     input=docs,
@@ -147,24 +150,26 @@ async def acall(self, docs: List[str]) -> List[List[float]]:
                 )
                 if embeds.data:
                     break
-            except OpenAIError as e:
-                # print full traceback
-                import traceback
-
-                traceback.print_exc()
-                await asleep(2**j)
-                error_message = str(e)
-                logger.warning(f"Retrying in {2**j} seconds...")
+            except OpenAIError as e:
+                logger.error("Exception occurred", exc_info=True)
+                if self.max_retries != 0 and j < self.max_retries:
+                    await asleep(2**j)
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
+                else:
+                    raise
+
             except Exception as e:
-                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
+                logger.error(f"Azure OpenAI API call failed. Error: {e}")
                 raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e

         if (
             not embeds
             or not isinstance(embeds, CreateEmbeddingResponse)
             or not embeds.data
         ):
-            raise ValueError(f"No embeddings returned. Error: {error_message}")
+            raise ValueError("No embeddings returned.")

         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
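
A matching usage sketch for the Azure encoder. This assumes the Azure OpenAI key, endpoint, and API version that the constructor expects are already configured (the diff above only shows part of the constructor); the deployment and model names below are placeholders, not values from this repository:

from semantic_router.encoders.zure import AzureOpenAIEncoder

# max_retries caps the same exponential-backoff loop shown in the diff above;
# credentials and endpoint are assumed to come from the environment/config the
# encoder reads, and the names below are placeholders.
encoder = AzureOpenAIEncoder(
    deployment_name="my-embedding-deployment",
    model="text-embedding-3-small",
    max_retries=2,
)
embeddings = encoder(["semantic routing is fun"])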
6 changes: 5 additions & 1 deletion semantic_router/utils/function_call.py
@@ -77,7 +77,11 @@ def to_ollama(self):
             "type": "object",
             "properties": {
                 param.name: {
-                    "description": param.description,
+                    "description": (
+                        param.description
+                        if isinstance(param.description, str)
+                        else None
+                    ),
                     "type": self._ollama_type_mapping(param.type),
                 }
                 for param in self.parameters
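The guard added above only changes behaviour when a parameter description is missing or not a string; a tiny standalone illustration of the same coercion (plain Python, not the library's API):

def safe_description(description):
    # Mirrors the conditional in to_ollama: strings pass through unchanged,
    # anything else (None, numbers, objects) becomes None in the schema.
    return description if isinstance(description, str) else None


print(safe_description("city to fetch weather for"))  # passes through
print(safe_description(None))                         # stays None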
3 changes: 2 additions & 1 deletion tests/integration/encoders/test_openai_integration.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+from openai import OpenAIError
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.openai import OpenAIEncoder

@@ -40,7 +41,7 @@ def test_openai_encoder_call_truncation(self, openai_encoder):
         os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key required"
     )
     def test_openai_encoder_call_no_truncation(self, openai_encoder):
-        with pytest.raises(ValueError) as _:
+        with pytest.raises(OpenAIError) as _:
             # default truncation is True
             openai_encoder([long_doc], truncate=False)

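The updated assertion reflects the new failure mode: once retries are exhausted, the encoder re-raises the underlying OpenAIError instead of wrapping it in a ValueError. A hedged sketch of a complementary test; the test name, document contents, model name, and the assumption that an over-long, untruncated input triggers an API error are illustrative, not part of the repository's suite:

import pytest
from openai import OpenAIError
from semantic_router.encoders.openai import OpenAIEncoder


def test_openai_encoder_fails_fast_without_retries():
    # max_retries=0 means a single attempt: the first OpenAIError propagates.
    encoder = OpenAIEncoder(name="text-embedding-3-small", max_retries=0)
    with pytest.raises(OpenAIError):
        encoder(["some text " * 10_000], truncate=False)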