Skip to content

Commit

Permalink
Merge pull request #1090 from JohnSnowLabs/fix/augmentation-config-va…
Browse files Browse the repository at this point in the history
…ries-even-when-no-transformations-are-applied

resolved/augmentation errors
  • Loading branch information
chakravarthik27 authored Aug 16, 2024
2 parents f377ff0 + b7f68c1 commit 64475eb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
7 changes: 6 additions & 1 deletion langtest/augmentation/augmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, task: Union[str, TaskManager], config: Union[str, dict]) -> N
if isinstance(config, str):
self.__config = self.load_config(config)

self.__tests: dict = self.__config.get("tests", [])
self.__tests: Dict[str, Dict[str, dict]] = self.__config.get("tests", [])
if isinstance(task, str):
if task in ["ner", "text-classification", "question-answering"]:
task = TaskManager(task)
Expand Down Expand Up @@ -66,6 +66,8 @@ def augment(self, data: Union[str, Iterable]) -> str:
self.__datafactory = self.__datafactory(file_path=data, task=self.__task)

data = self.__datafactory.load()
elif isinstance(self.__datafactory, DataFactory):
data = self.__datafactory.load()

# generate the augmented data
test_cases = self.__testfactory.transform(self.__task, data, self.__tests)
Expand Down Expand Up @@ -276,6 +278,9 @@ def __initialize_config_df(self) -> pd.DataFrame:
)
df = pd.concat([df, pd.DataFrame(temp_data)], ignore_index=True)

# Convert 'proportion' column to float
df["proportion"] = pd.to_numeric(df["proportion"], errors="coerce")

# normalize the proportion and round it to 2 decimal places
df["normalized_proportion"] = df["proportion"] / df["proportion"].sum()
df["normalized_proportion"] = df["normalized_proportion"].apply(
Expand Down
4 changes: 2 additions & 2 deletions langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def __init__(

self.__templates.extend(generated_templates[:num_extra_templates])
except Exception as e:
raise Errors.E095(e)
raise Errors.E095(msg=e)

if show_templates:
[print(template) for template in self.__templates]
Expand Down Expand Up @@ -609,7 +609,7 @@ class Templates(BaseModel):
def __post_init__(self):
self.templates = [i.strip('"') for i in self.templates]

@validator("templates", each_item=True)
@validator("templates", each_item=True, allow_reuse=True)
def check_templates(cls, v: str):
if not v:
raise ValueError("No templates generated.")
Expand Down
2 changes: 1 addition & 1 deletion langtest/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class Errors(metaclass=ErrorsWithCodes):
E093 = ("Category cannot be None. Please provide a valid category.")
E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
E095 = ("Failed to make API request: {e}")
E096 = ("Failed to generate the templates in Augmentation: {e}")
E096 = ("Failed to generate the templates in Augmentation: {msg}")


class ColumnNameError(Exception):
Expand Down

0 comments on commit 64475eb

Please sign in to comment.