Skip to content

Commit

Permalink
Merge pull request #1091 from JohnSnowLabs/fix/augmentation-config-va…
Browse files Browse the repository at this point in the history
…ries-even-when-no-transformations-are-applied

Fix/augmentations
  • Loading branch information
chakravarthik27 authored Sep 2, 2024
2 parents 64475eb + 85d7e70 commit cc821c9
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 61 deletions.
86 changes: 26 additions & 60 deletions langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
import yaml

from langtest.augmentation.utils import AzureOpenAIConfig, OpenAIConfig
from langtest.datahandler.datasource import DataFactory
from langtest.transform import TestFactory
from langtest.transform.utils import create_terminology
Expand Down Expand Up @@ -324,6 +325,7 @@ def __init__(
generate_templates=False,
show_templates=False,
num_extra_templates=10,
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
) -> None:
"""This constructor for the TemplaticAugment class.
Expand All @@ -341,21 +343,23 @@ def __init__(
given_template = self.__templates[:]
for template in given_template:
generated_templates: List[str] = self.__generate_templates(
template, num_extra_templates
template, num_extra_templates, model_config
)

while len(generated_templates) < num_extra_templates:
temp_templates = self.__generate_templates(
template, num_extra_templates
template,
num_extra_templates,
model_config,
)
generated_templates.extend(temp_templates)

if generated_templates:
# Extend the existing templates list

self.__templates.extend(generated_templates[:num_extra_templates])
except Exception as e:
raise Errors.E095(msg=e)
except Exception as e_msg:
raise Errors.E095(e=e_msg)

if show_templates:
[print(template) for template in self.__templates]
Expand Down Expand Up @@ -596,63 +600,25 @@ def add_spaces_around_punctuation(text: str):

return text

def __generate_templates(self, template, num_extra_templates) -> List[str]:
def __generate_templates(
self,
template: str,
num_extra_templates: int,
model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
) -> List[str]:
if try_import_lib("openai"):
import openai
from pydantic import BaseModel, validator

client = openai.OpenAI()

class Templates(BaseModel):
templates: List[str]

def __post_init__(self):
self.templates = [i.strip('"') for i in self.templates]

@validator("templates", each_item=True, allow_reuse=True)
def check_templates(cls, v: str):
if not v:
raise ValueError("No templates generated.")
return v.strip('"')

def remove_invalid_templates(self, original_template):
# extract variable names using regex
regexs = r"{([^{}]*)}"
original_vars = re.findall(regexs, original_template)
original_vars = set([var.strip() for var in original_vars])

# remove invalid templates
valid_templates = []
for template in self.templates:
template_vars: List[str] = re.findall(regexs, template)
template_vars = set([var.strip() for var in template_vars])
if template_vars == original_vars:
valid_templates.append(template)
self.templates = valid_templates

prompt = (
f"Based on the provided template, create {num_extra_templates} new and unique templates that are "
"variations on this theme. Present these as a list, with each template as a quoted string. The list should "
"contain only the templates, without any additional text or explanation. Ensure that the structure of "
"these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
"Template:\n"
f"{template}\n"
)
response = client.beta.chat.completions.parse(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": f"Action: Generate up to {num_extra_templates} templates and ensure that the structure of the variables within the templates remains unchanged and don't add any extra variables.",
},
{"role": "user", "content": prompt},
],
max_tokens=500,
temperature=0,
response_format=Templates,
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
)

generated_response = response.choices[0].message.parsed
generated_response.remove_invalid_templates(template)
params = model_config.copy() if model_config else {}

if model_config and model_config.get("provider") == "openai":
return generate_templates_openai(template, num_extra_templates, params)

return generated_response.templates[:num_extra_templates]
elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)

else:
return generate_templates_openai(template, num_extra_templates)
174 changes: 174 additions & 0 deletions langtest/augmentation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import os
import re
from typing import Callable, List, TypedDict, Union

from pydantic import BaseModel, validator

from langtest.logger import logger


class OpenAIConfig(TypedDict, total=False):
    """Keyword arguments for building an OpenAI client, plus a provider tag.

    Every key is optional (``total=False``), so an empty ``OpenAIConfig()``
    is a valid value; this matches how the type is used as an "empty"
    default configuration.

    A ``TypedDict`` body may only contain annotations: values assigned in the
    class body are never applied to created dicts, so no defaults are declared
    here.
    NOTE(review): when ``api_key`` is omitted the OpenAI client is expected to
    fall back to the ``OPENAI_API_KEY`` environment variable -- confirm against
    the installed ``openai`` version.

    Keys:
        api_key: API key used to authenticate.
        base_url: Alternative API base URL, or ``None``.
        organization: Organization identifier, or ``None``.
        project: Project identifier, or ``None``.
        provider: Dispatch tag for selecting the backend (e.g. ``"openai"``);
            it is not a client argument.
    """

    api_key: str
    base_url: Union[str, None]
    organization: Union[str, None]
    project: Union[str, None]
    provider: str


class _AzureOpenAIRequired(TypedDict):
    """Keys that every Azure OpenAI configuration must provide."""

    azure_endpoint: str
    api_version: str
    api_key: str
    provider: str


class AzureOpenAIConfig(_AzureOpenAIRequired, total=False):
    """Keyword arguments for building an Azure OpenAI client, plus a provider tag.

    The required keys (``azure_endpoint``, ``api_version``, ``api_key``,
    ``provider``) come from ``_AzureOpenAIRequired``; the keys declared here
    are optional.

    A ``TypedDict`` body may only contain annotations, so no default values
    are declared. The token provider is annotated structurally as a
    zero-argument callable returning a string, so that importing this module
    does not require the optional ``openai`` package.

    Keys:
        azure_endpoint: Endpoint URL of the Azure OpenAI resource.
        api_version: Azure OpenAI API version string.
        api_key: API key used to authenticate.
        provider: Dispatch tag for selecting the backend (e.g. ``"azure"``);
            it is not a client argument.
        azure_deployment: Deployment name, or ``None``.
        azure_ad_token: Azure AD token, or ``None``.
        azure_ad_token_provider: Callable returning an Azure AD token, or
            ``None``.
        organization: Organization identifier, or ``None``.
    """

    azure_deployment: Union[str, None]
    azure_ad_token: Union[str, None]
    azure_ad_token_provider: Union[Callable[[], str], None]
    organization: Union[str, None]


class Templates(BaseModel):
    """Pydantic model that holds and validates a batch of generated templates."""

    templates: List[str]

    def __post_init__(self):
        """Strip surrounding double quotes from each template and log the list.

        NOTE(review): ``__post_init__`` is a dataclass hook; confirm that
        pydantic's ``BaseModel`` actually calls it, otherwise this method only
        runs when invoked explicitly.
        """
        unquoted = []
        for item in self.templates:
            unquoted.append(item.strip('"'))
        self.templates = unquoted
        logger.info(f"Generated templates: {self.templates}")

    @validator("templates", each_item=True, allow_reuse=True)
    def check_templates(cls, v: str):
        """Reject empty templates; return the template without surrounding quotes."""
        if v:
            return v.strip('"')
        raise ValueError("No templates generated.")

    def remove_invalid_templates(self, original_template):
        """Keep only templates whose ``{placeholder}`` set matches the original's.

        Placeholder names are compared as whitespace-stripped sets, so order
        and repetition are ignored. Each decision is logged, and
        ``self.templates`` is replaced by the surviving templates.
        """
        placeholder_pattern = r"{([^{}]*)}"

        def placeholder_names(text):
            # Names between single-level braces, with padding removed.
            return {name.strip() for name in re.findall(placeholder_pattern, text)}

        expected = placeholder_names(original_template)

        kept = []
        for candidate in self.templates:
            found = placeholder_names(candidate)
            if found != expected:
                logger.warning(
                    f"Invalid Variables in template: {candidate} - {found}"
                )
                continue
            kept.append(candidate)
            logger.info(f"Valid template: {candidate}")

        self.templates = kept
        logger.info(f"Valid templates: {self.templates}")


def generate_templates_azoi(
    template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
):
    """Generate variations of ``template`` through the Azure OpenAI chat API.

    Args:
        template: Source template containing ``{placeholder}`` variables.
        num_extra_templates: Maximum number of templates to return.
        model_config: Azure client settings. The ``"provider"`` key, if
            present, is ignored; the mapping itself is not modified.

    Returns:
        Up to ``num_extra_templates`` generated templates whose placeholder
        set matches the one in ``template``.

    Raises:
        ValueError: If the model returns no content, or content that cannot
            be decoded as a list of templates.
    """
    import ast
    import json

    import openai

    # Filter into a new dict rather than deleting from the caller's mapping,
    # so a config object can be reused across calls.
    client_kwargs = {
        key: value for key, value in model_config.items() if key != "provider"
    }
    client = openai.AzureOpenAI(**client_kwargs)

    prompt = (
        "Based on the provided template, create {num_extra_templates} new and unique templates that are "
        "variations on this theme. Present these as a list, with each template as a quoted string. The list should "
        "contain only the templates, without any additional text or explanation. Ensure that the structure of "
        "these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
        "Template:\n"
        "{template}\n"
    )

    # NOTE(review): the model name is hard-coded; confirm it matches the Azure
    # deployment that the supplied configuration points at.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": f"Generate up to {num_extra_templates} templates based on the provided template.\n\n JSON Output Schema: {Templates.schema()}\n",
            },
            {
                "role": "user",
                "content": prompt.format(
                    template="The {ORG} company is located in {LOC}",
                    num_extra_templates=2,
                ),
            },
            {
                "role": "assistant",
                "content": '["The {ORG} corporation is based out of {LOC}",\n "The {ORG} organization operates in {LOC}"]',
            },
            {
                "role": "user",
                "content": prompt.format(
                    template=template, num_extra_templates=num_extra_templates
                ),
            },
        ],
        temperature=0.1,
        max_tokens=1000,
    )

    raw_content = response.choices[0].message.content
    if not raw_content:
        logger.error("Error decoding response: empty response content")
        raise ValueError("Error decoding response: empty response content")

    try:
        # Parse the reply as-is first: rewriting every apostrophe as a double
        # quote would corrupt valid JSON such as ["Don't leave {LOC}"].
        parsed = json.loads(raw_content)
    except json.JSONDecodeError as json_err:
        try:
            # Fall back to a Python-style list literal (single-quoted strings).
            parsed = ast.literal_eval(raw_content)
        except (ValueError, SyntaxError):
            logger.error(f"Error decoding response: {json_err}")
            raise ValueError(f"Error decoding response: {json_err}") from json_err

    # The system prompt advertises the Templates object schema, so accept
    # {"templates": [...]} as well as a bare list.
    if isinstance(parsed, dict):
        parsed = parsed.get("templates", [])

    gen_templates = Templates(templates=parsed)
    gen_templates.remove_invalid_templates(template)

    return gen_templates.templates[:num_extra_templates]


def generate_templates_openai(
    template: str,
    num_extra_templates: int,
    model_config: Union[OpenAIConfig, None] = None,
):
    """Generate variations of ``template`` through the OpenAI chat API.

    Args:
        template: Source template containing ``{placeholder}`` variables.
        num_extra_templates: Maximum number of templates to return.
        model_config: Optional client settings. ``None`` (the default) builds
            the client with no explicit arguments. The ``"provider"`` key, if
            present, is ignored; the mapping itself is not modified.

    Returns:
        Up to ``num_extra_templates`` generated templates whose placeholder
        set matches the one in ``template``.

    Raises:
        ValueError: If the response carries no parsed ``Templates`` payload.
    """
    import openai

    # A None default replaces a dict created once at definition time, and the
    # kwargs are filtered into a new dict so the caller's config is untouched.
    client_kwargs = {
        key: value
        for key, value in (model_config or {}).items()
        if key != "provider"
    }
    client = openai.OpenAI(**client_kwargs)

    prompt = (
        f"Based on the provided template, create {num_extra_templates} new and unique templates that are "
        "variations on this theme. Present these as a list, with each template as a quoted string. The list should "
        "contain only the templates, without any additional text or explanation. Ensure that the structure of "
        "these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
        "Template:\n"
        f"{template}\n"
    )
    response = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": f"Action: Generate up to {num_extra_templates} templates and ensure that the structure of the variables within the templates remains unchanged and don't add any extra variables.",
            },
            {"role": "user", "content": prompt},
        ],
        # Same budget as the Azure variant: a 100-token cap cannot hold a JSON
        # list of several templates and leads to truncated responses.
        max_tokens=1000,
        temperature=0.1,
        response_format=Templates,
    )

    generated_response = response.choices[0].message.parsed
    if generated_response is None:
        # No structured payload was produced; fail with a clear error instead
        # of an AttributeError on None.
        raise ValueError("No templates generated.")

    generated_response.remove_invalid_templates(template)

    return generated_response.templates[:num_extra_templates]
2 changes: 1 addition & 1 deletion langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import csv
import importlib
import logging
import os
import random
import re
Expand All @@ -11,6 +10,7 @@
import jsonlines
import pandas as pd
from langtest.tasks.task import TaskManager
from langtest.logger import logger as logging

from .format import Formatter
from langtest.utils.custom_types import (
Expand Down

0 comments on commit cc821c9

Please sign in to comment.