Skip to content

Commit

Permalink
Merge pull request #1102 from JohnSnowLabs/fix/module_error-with-open…
Browse files Browse the repository at this point in the history
…ai-package

Fix/module error with openai package
  • Loading branch information
chakravarthik27 authored Sep 11, 2024
2 parents f1fbdc1 + e312032 commit 4448c71
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
pull_request:
branches:
- "release/*"
- "patch/*"
- "main"

jobs:
Expand Down
28 changes: 15 additions & 13 deletions langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from langtest.utils.custom_types.predictions import NERPrediction, SequenceLabel
from langtest.utils.custom_types.sample import NERSample
from langtest.tasks import TaskManager
from ..utils.lib_manager import try_import_lib
from ..errors import Errors


Expand Down Expand Up @@ -358,6 +357,9 @@ def __init__(
# Extend the existing templates list

self.__templates.extend(generated_templates[:num_extra_templates])
except ModuleNotFoundError:
raise ImportError(Errors.E097())

except Exception as e_msg:
raise Errors.E095(e=e_msg)

Expand Down Expand Up @@ -606,19 +608,19 @@ def __generate_templates(
num_extra_templates: int,
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
) -> List[str]:
if try_import_lib("openai"):
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
)
"""This method is used to generate extra templates from a given template."""
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
)

params = model_config.copy() if model_config else {}
params = model_config.copy() if model_config else {}

if model_config and model_config.get("provider") == "openai":
return generate_templates_openai(template, num_extra_templates, params)
if model_config and model_config.get("provider") == "openai":
return generate_templates_openai(template, num_extra_templates, params)

elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)
elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)

else:
return generate_templates_openai(template, num_extra_templates)
else:
return generate_templates_openai(template, num_extra_templates)
6 changes: 3 additions & 3 deletions langtest/augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ class OpenAIConfig(TypedDict):
class AzureOpenAIConfig(TypedDict):
"""Azure OpenAI Configuration for API Key and Provider."""

from openai.lib.azure import AzureADTokenProvider

azure_endpoint: str
api_version: str
api_key: str
provider: str
azure_deployment: Union[str, None] = None
azure_ad_token: Union[str, None] = (None,)
azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
azure_ad_token_provider = (None,)
organization: Union[str, None] = (None,)


Expand Down Expand Up @@ -76,6 +74,7 @@ def generate_templates_azoi(
template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
):
"""Generate new templates based on the provided template using Azure OpenAI API."""

import openai

if "provider" in model_config:
Expand Down Expand Up @@ -139,6 +138,7 @@ def generate_templates_openai(
template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
):
"""Generate new templates based on the provided template using OpenAI API."""

import openai

if "provider" in model_config:
Expand Down
1 change: 1 addition & 0 deletions langtest/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ class Errors(metaclass=ErrorsWithCodes):
E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
E095 = ("Failed to make API request: {e}")
E096 = ("Failed to generate the templates in Augmentation: {msg}")
E097 = ("Failed to load openai. Please install it using `pip install openai`")


class ColumnNameError(Exception):
Expand Down

0 comments on commit 4448c71

Please sign in to comment.