Skip to content

Commit

Permalink
feat(builder): add Azure Open AI Compatibility (#269)
Browse files Browse the repository at this point in the history
* feat(llm): add Azure OpenAI client and vectorization support

* chore: add .DS_Store to .gitignore

* refactor(llm):add description for api_version and default value

* refactor(vectorize_model): added description for api_version and default values for some params

* refactor(openai_model): enhance docstring for Azure AD token and deployment parameters
  • Loading branch information
joseosvaldo16 authored Jan 14, 2025
1 parent 671a9a0 commit 6494fd2
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
.idea/
.venv/
__pycache__/
.DS_Store
121 changes: 120 additions & 1 deletion kag/common/llm/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@


import json
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
import logging

from kag.interface import LLMClient
from tenacity import retry, stop_after_attempt
from typing import Callable

logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)

AzureADTokenProvider = Callable[[], str]

@LLMClient.register("maas")
@LLMClient.register("openai")

class OpenAIClient(LLMClient):
"""
A client class for interacting with the OpenAI API.
Expand Down Expand Up @@ -134,3 +137,119 @@ def call_with_json_parse(self, prompt):
except:
return rsp
return json_result
@LLMClient.register("azure_openai")
class AzureOpenAIClient(LLMClient):
    """
    A client class for interacting with the Azure OpenAI chat-completions API.

    Mirrors the interface of ``OpenAIClient`` (callable with a prompt, plus
    ``call_with_json_parse``) but routes requests through an ``AzureOpenAI``
    client, which supports deployment-scoped endpoints and Azure AD auth.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        stream: bool = False,
        api_version: str = "2024-12-01-preview",
        temperature: float = 0.7,
        azure_deployment: str = None,
        timeout: float = None,
        azure_ad_token: str = None,
        azure_ad_token_provider: AzureADTokenProvider = None,
    ):
        """
        Initializes the AzureOpenAIClient instance.

        Args:
            api_key (str): The API key for accessing the Azure OpenAI API.
            base_url (str): The base URL for the Azure OpenAI API.
            model (str): The default model to use for requests.
            stream (bool, optional): Whether to stream the response. Defaults to False.
            api_version (str): The API version for the Azure OpenAI API
                (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
            temperature (float, optional): The temperature parameter for the model. Defaults to 0.7.
            azure_deployment (str, optional): A model deployment name. If given, the client URL
                includes `/deployments/{azure_deployment}`. Note: this means you won't be able to
                use non-deployment endpoints. Not supported with Assistants APIs.
            timeout (float, optional): The timeout duration for the service request.
                Defaults to None, meaning no timeout.
            azure_ad_token (str, optional): Your Azure Active Directory token,
                https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
            azure_ad_token_provider (Callable, optional): A function that returns an Azure Active
                Directory token; it will be invoked on every request.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.azure_deployment = azure_deployment
        self.model = model
        self.stream = stream
        self.temperature = temperature
        self.timeout = timeout
        self.api_version = api_version
        self.azure_ad_token = azure_ad_token
        self.azure_ad_token_provider = azure_ad_token_provider
        # NOTE: AzureOpenAI's constructor takes no ``model`` keyword — passing
        # one raises TypeError. The model is supplied per request in
        # ``chat.completions.create`` instead.
        self.client = AzureOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            azure_deployment=self.azure_deployment,
            api_version=self.api_version,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
        self.check()

    def __call__(self, prompt: str, image_url: str = None):
        """
        Executes a model request when the object is called and returns the result.

        Parameters:
            prompt (str): The prompt provided to the model.
            image_url (str, optional): If given, the prompt is sent as a multimodal
                message that also references this image URL.
        Returns:
            str: The response content generated by the model.
        """
        # Build the user content once; only its shape differs between the
        # text-only and multimodal cases.
        if image_url:
            user_content = [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]
        else:
            user_content = prompt
        message = [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": user_content},
        ]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=message,
            stream=self.stream,
            temperature=self.temperature,
            timeout=self.timeout,
        )
        return response.choices[0].message.content

    @retry(stop=stop_after_attempt(3))
    def call_with_json_parse(self, prompt):
        """
        Calls the model and attempts to parse the response into JSON format.

        Parameters:
            prompt (str): The prompt provided to the model.
        Returns:
            Union[dict, str]: If the response is valid JSON, returns the parsed
            object; otherwise, returns the original response string.
        """
        rsp = self(prompt)
        # Strip a Markdown ```json ... ``` fence if the model wrapped its answer.
        _end = rsp.rfind("```")
        _start = rsp.find("```json")
        if _end != -1 and _start != -1:
            json_str = rsp[_start + len("```json") : _end].strip()
        else:
            json_str = rsp
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            # Not valid JSON — fall back to the raw model output.
            return rsp
72 changes: 70 additions & 2 deletions kag/common/vectorize_model/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
# or implied.

from typing import Union, Iterable
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
from kag.interface import VectorizeModelABC, EmbeddingVector

from typing import Callable

@VectorizeModelABC.register("openai")
class OpenAIVectorizeModel(VectorizeModelABC):
Expand Down Expand Up @@ -65,3 +65,71 @@ def vectorize(
else:
assert len(results) == len(texts)
return results

@VectorizeModelABC.register("azure_openai")
class AzureOpenAIVectorizeModel(VectorizeModelABC):
    """A class that extends the VectorizeModelABC base class.

    It invokes Azure OpenAI or Azure OpenAI-compatible embedding services to
    convert texts into embedding vectors.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "text-embedding-ada-002",
        api_version: str = "2024-12-01-preview",
        vector_dimensions: int = None,
        timeout: float = None,
        azure_deployment: str = None,
        azure_ad_token: str = None,
        azure_ad_token_provider: Callable = None,
    ):
        """
        Initializes the AzureOpenAIVectorizeModel instance.

        Args:
            base_url (str): The base URL for the Azure OpenAI service.
            api_key (str): The API key for accessing the Azure OpenAI service.
            model (str, optional): The model to use for embedding.
                Defaults to "text-embedding-ada-002".
            api_version (str): The API version for the Azure OpenAI API
                (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
            vector_dimensions (int, optional): The number of dimensions for the
                embedding vectors. Defaults to None.
            timeout (float, optional): The timeout duration for the service request.
                Defaults to None, meaning no timeout.
            azure_deployment (str, optional): A model deployment name. If given, the client URL
                includes `/deployments/{azure_deployment}`. Note: this means you won't be able to
                use non-deployment endpoints. Not supported with Assistants APIs.
            azure_ad_token (str, optional): Your Azure Active Directory token,
                https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
            azure_ad_token_provider (Callable, optional): A function that returns an Azure Active
                Directory token; it will be invoked on every request.
        """
        super().__init__(vector_dimensions)
        self.model = model
        self.timeout = timeout
        # NOTE: AzureOpenAI's constructor takes no ``model`` keyword — passing
        # one raises TypeError. The model is supplied per request in
        # ``embeddings.create`` instead.
        self.client = AzureOpenAI(
            api_key=api_key,
            base_url=base_url,
            azure_deployment=azure_deployment,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
        )

    def vectorize(
        self, texts: Union[str, Iterable[str]]
    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
        """
        Vectorizes a text string into an embedding vector or multiple text strings
        into multiple embedding vectors.

        Args:
            texts (Union[str, Iterable[str]]): The text or texts to vectorize.
        Returns:
            Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding
            vector(s) of the text(s) — a single vector for a single string input,
            a list of vectors otherwise.
        """
        results = self.client.embeddings.create(
            input=texts, model=self.model, timeout=self.timeout
        )
        results = [item.embedding for item in results.data]
        if isinstance(texts, str):
            # A single input string yields exactly one embedding.
            assert len(results) == 1
            return results[0]
        else:
            assert len(results) == len(texts)
            return results

0 comments on commit 6494fd2

Please sign in to comment.