feat: Added AzureOpenAILLM #112

Merged · 13 commits · Jan 23, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,5 +25,5 @@ output
node_modules
package-lock.json
package.json

test.ipynb
15 changes: 1 addition & 14 deletions poetry.lock


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,6 +27,7 @@ fastembed = {version = "^0.1.3", optional = true, python = "<3.12"}
torch = {version = "^2.1.2", optional = true}
transformers = {version = "^4.36.2", optional = true}
llama-cpp-python = {version = "^0.2.28", optional = true}
black = "^23.12.1"

[tool.poetry.extras]
hybrid = ["pinecone-text"]
@@ -36,7 +37,6 @@ local = ["torch", "transformers", "llama-cpp-python"]
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
ruff = "^0.1.5"
black = {extras = ["jupyter"], version = "^23.12.0"}
pytest = "^7.4.3"
pytest-mock = "^3.12.0"
pytest-cov = "^4.1.0"
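Note that this hunk moves `black` from the dev group into the main dependencies (and bumps the pin). For anyone replaying the change locally, the equivalent Poetry commands would be roughly:

```bash
poetry remove --group dev black
poetry add "black@^23.12.1"
```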
2 changes: 2 additions & 0 deletions semantic_router/layer.py
@@ -194,10 +194,12 @@ def __call__(self, text: str) -> RouteChoice:
"default. Ensure API key is set in OPENAI_API_KEY environment "
"variable."
)

self.llm = OpenAILLM()
route.llm = self.llm
else:
route.llm = self.llm
logger.info(f"LLM `{route.llm}` is chosen")
return route(text)
else:
# if no route passes threshold, return empty route choice
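The logic above falls back to a default `OpenAILLM` whenever a dynamic route reaches the layer without an LLM attached. A minimal sketch of attaching the new Azure LLM to a route up front, so the fallback never triggers — the route name, utterances, and schema dict here are hypothetical:

```python
import os

from semantic_router import Route
from semantic_router.llms import AzureOpenAILLM

# Hypothetical dynamic route; the schema dict mirrors the "Input schema"
# that _generate_dynamic_route feeds into its prompt.
route = Route(
    name="get_time",
    utterances=["what time is it?", "tell me the time"],
    function_schema={
        "name": "get_time",
        "description": "Returns the current time for a timezone.",
        "signature": "(timezone: str) -> str",
    },
)

# Attaching an LLM to the route directly means the OpenAILLM fallback in
# RouteLayer.__call__ (shown above) is never taken.
route.llm = AzureOpenAILLM(
    openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
)
```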
3 changes: 2 additions & 1 deletion semantic_router/llms/__init__.py
@@ -2,5 +2,6 @@
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
from semantic_router.llms.zure import AzureOpenAILLM

__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM"]
__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM"]
5 changes: 3 additions & 2 deletions semantic_router/llms/base.py
@@ -31,7 +31,6 @@ def _is_valid_inputs(
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]

for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
@@ -76,12 +75,14 @@ def extract_function_inputs(
"""
llm_input = [Message(role="user", content=prompt)]
output = self(llm_input)

if not output:
raise Exception("No output generated for extract function input")

output = output.replace("'", '"').strip().rstrip(",")

logger.info(f"LLM output: {output}")
function_inputs = json.loads(output)
logger.info(f"Function inputs: {function_inputs}")
if not self._is_valid_inputs(function_inputs, function_schema):
raise ValueError("Invalid inputs")
return function_inputs
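The `replace`/`strip`/`rstrip` chain above is a best-effort repair of near-JSON model output before `json.loads`. A small self-contained illustration of what it fixes, and of its main caveat:

```python
import json

# What the cleanup repairs: a single-quoted, trailing-comma "JSON" reply,
# which json.loads would otherwise reject.
raw = "{'timezone': 'Europe/Berlin'},"
repaired = raw.replace("'", '"').strip().rstrip(",")
inputs = json.loads(repaired)
assert inputs == {"timezone": "Europe/Berlin"}

# Caveat: the replace() is blunt. An apostrophe inside a value
# (e.g. "it's 5pm") would also be rewritten and break the parse.
```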
61 changes: 61 additions & 0 deletions semantic_router/llms/zure.py
@@ -0,0 +1,61 @@
import os
from typing import List, Optional

import openai

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class AzureOpenAILLM(BaseLLM):
client: Optional[openai.AzureOpenAI]
temperature: Optional[float]
max_tokens: Optional[int]

def __init__(
self,
name: Optional[str] = None,
openai_api_key: Optional[str] = None,
azure_endpoint: Optional[str] = None,
temperature: float = 0.01,
max_tokens: int = 200,
api_version="2023-07-01-preview",
):
if name is None:
name = os.getenv("OPENAI_CHAT_MODEL_NAME", "gpt-3.5-turbo")
super().__init__(name=name)
api_key = openai_api_key or os.getenv("AZURE_OPENAI_API_KEY")
if api_key is None:
raise ValueError("AzureOpenAI API key cannot be 'None'.")

azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError("Azure endpoint cannot be 'None'.")
try:
self.client = openai.AzureOpenAI(
api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
)
except Exception as e:
raise ValueError(f"AzureOpenAI API client failed to initialize. Error: {e}")
self.temperature = temperature
self.max_tokens = max_tokens

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
raise ValueError("AzureOpenAI client is not initialized.")
try:
completion = self.client.chat.completions.create(
model=self.name,
messages=[m.to_openai() for m in messages],
temperature=self.temperature,
max_tokens=self.max_tokens,
)

output = completion.choices[0].message.content

if not output:
raise Exception("No output generated")

return output
except Exception as e:
logger.error(f"LLM error: {e}")
raise Exception(f"LLM error: {e}")

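For reference, a minimal usage sketch of the new class. It assumes the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` environment variables are set; `my-gpt35-deployment` is a hypothetical deployment name (with Azure OpenAI, `name` should match an Azure deployment rather than a bare model id):

```python
import os

from semantic_router.llms import AzureOpenAILLM
from semantic_router.schema import Message

llm = AzureOpenAILLM(
    name="my-gpt35-deployment",  # hypothetical Azure deployment name
    openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
)

# __call__ takes a list of Message objects and returns the completion text.
reply = llm([Message(role="user", content="Say hello in one word.")])
print(reply)
```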
47 changes: 24 additions & 23 deletions semantic_router/route.py
@@ -46,6 +46,7 @@ class Route(BaseModel):
llm: Optional[BaseLLM] = None

def __call__(self, query: str) -> RouteChoice:
logger.info(f"LLM `{self.llm}` passed to route")
if self.function_schema:
if not self.llm:
raise ValueError(
@@ -96,29 +97,29 @@ def _generate_dynamic_route(cls, llm: BaseLLM, function_schema: Dict[str, Any]):
logger.info("Generating dynamic route...")

prompt = f"""
You are tasked to generate a JSON configuration based on the provided
function schema. Please follow the template below, no other tokens allowed:

<config>
{{
"name": "<function_name>",
"utterances": [
"<example_utterance_1>",
"<example_utterance_2>",
"<example_utterance_3>",
"<example_utterance_4>",
"<example_utterance_5>"]
}}
</config>

Only include the "name" and "utterances" keys in your answer.
The "name" should match the function name and the "utterances"
should comprise a list of 5 example phrases that could be used to invoke
the function. Use real values instead of placeholders.

Input schema:
{function_schema}
"""

llm_input = [Message(role="user", content=prompt)]
output = llm(llm_input)
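For illustration, a hypothetical well-formed completion for this prompt, plus the sort of extraction a caller would perform on it (a sketch, not the library's exact parsing code):

```python
import json
import re

# Hypothetical LLM completion following the <config> template above.
output = """<config>
{"name": "get_time",
 "utterances": [
    "what time is it in Tokyo?",
    "tell me the current time",
    "do you know the time?",
    "what's the time right now?",
    "current time in Berlin, please"]}
</config>"""

match = re.search(r"<config>(.*?)</config>", output, re.DOTALL)
assert match is not None, "no <config> block in LLM output"
config = json.loads(match.group(1))
assert set(config) == {"name", "utterances"}
assert len(config["utterances"]) == 5
```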
94 changes: 94 additions & 0 deletions tests/unit/llms/test_llm_azure_openai.py
@@ -0,0 +1,94 @@
import pytest

from semantic_router.llms import AzureOpenAILLM
from semantic_router.schema import Message


@pytest.fixture
def azure_openai_llm(mocker):
mocker.patch("openai.Client")
return AzureOpenAILLM(openai_api_key="test_api_key", azure_endpoint="test_endpoint")


class TestAzureOpenAILLM:
def test_azure_openai_llm_init_with_api_key(self, azure_openai_llm):
assert azure_openai_llm.client is not None, "Client should be initialized"
assert (
azure_openai_llm.name == "gpt-3.5-turbo"
), "Default name not set correctly"

def test_azure_openai_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = AzureOpenAILLM()
assert llm.client is not None

def test_azure_openai_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
with pytest.raises(ValueError) as _:
AzureOpenAILLM()

def test_azure_openai_llm_init_without_azure_endpoint(self, mocker):
mocker.patch(
"os.getenv",
side_effect=lambda key, default=None: {
"OPENAI_CHAT_MODEL_NAME": "test-model-name"
}.get(key, default),
)
with pytest.raises(ValueError) as e:
AzureOpenAILLM(openai_api_key="test_api_key")
assert "Azure endpoint cannot be 'None'" in str(e.value)

def test_azure_openai_llm_call_uninitialized_client(self, azure_openai_llm):
# Set the client to None to simulate an uninitialized client
azure_openai_llm.client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
azure_openai_llm(llm_input)
assert "AzureOpenAI client is not initialized." in str(e.value)

def test_azure_openai_llm_init_exception(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch(
"openai.AzureOpenAI", side_effect=Exception("Initialization error")
)
with pytest.raises(ValueError) as e:
AzureOpenAILLM()
assert (
"AzureOpenAI API client failed to initialize. Error: Initialization error"
in str(e.value)
)

def test_azure_openai_llm_temperature_max_tokens_initialization(self):
test_temperature = 0.5
test_max_tokens = 100
azure_llm = AzureOpenAILLM(
openai_api_key="test_api_key",
azure_endpoint="test_endpoint",
temperature=test_temperature,
max_tokens=test_max_tokens,
)

assert (
azure_llm.temperature == test_temperature
), "Temperature not set correctly"
assert azure_llm.max_tokens == test_max_tokens, "Max tokens not set correctly"

def test_azure_openai_llm_call_success(self, azure_openai_llm, mocker):
mock_completion = mocker.MagicMock()
mock_completion.choices[0].message.content = "test"

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
azure_openai_llm.client.chat.completions,
"create",
return_value=mock_completion,
)
llm_input = [Message(role="user", content="test")]
output = azure_openai_llm(llm_input)
assert output == "test"
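To run just this module locally (assuming the dev dependencies from `pyproject.toml` are installed):

```bash
poetry run pytest tests/unit/llms/test_llm_azure_openai.py -v
```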