From 6734f706a6a68db10856d0e4a49e0fb48d6dca0a Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Fri, 16 Aug 2024 11:39:56 +0530 Subject: [PATCH 1/3] chore: Fix error message in Augmentation when generating templates --- langtest/augmentation/base.py | 4 ++-- langtest/errors.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/langtest/augmentation/base.py b/langtest/augmentation/base.py index a1d01dfa6..116f30be6 100644 --- a/langtest/augmentation/base.py +++ b/langtest/augmentation/base.py @@ -355,7 +355,7 @@ def __init__( self.__templates.extend(generated_templates[:num_extra_templates]) except Exception as e: - raise Errors.E095(e) + raise Errors.E095(msg=e) if show_templates: [print(template) for template in self.__templates] @@ -609,7 +609,7 @@ class Templates(BaseModel): def __post_init__(self): self.templates = [i.strip('"') for i in self.templates] - @validator("templates", each_item=True) + @validator("templates", each_item=True, allow_reuse=True) def check_templates(cls, v: str): if not v: raise ValueError("No templates generated.") diff --git a/langtest/errors.py b/langtest/errors.py index 4dfc38ce6..d3d7d1bba 100644 --- a/langtest/errors.py +++ b/langtest/errors.py @@ -274,7 +274,7 @@ class Errors(metaclass=ErrorsWithCodes): E093 = ("Category cannot be None. Please provide a valid category.") E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}") E095 = ("Failed to make API request: {e}") - E096 = ("Failed to generate the templates in Augmentation: {e}") + E096 = ("Failed to generate the templates in Augmentation: {msg}") class ColumnNameError(Exception): From 55d17e1d9bf319314812193c2d541104d35f9e0f Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Fri, 16 Aug 2024 14:30:50 +0530 Subject: [PATCH 2/3] chore: Refactor DataAugmenter to improve template generation and proportion handling --- langtest/augmentation/augmenter.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/langtest/augmentation/augmenter.py b/langtest/augmentation/augmenter.py index 2cd118aab..7af5a30d4 100644 --- a/langtest/augmentation/augmenter.py +++ b/langtest/augmentation/augmenter.py @@ -26,7 +26,7 @@ def __init__(self, task: Union[str, TaskManager], config: Union[str, dict]) -> N if isinstance(config, str): self.__config = self.load_config(config) - self.__tests: dict = self.__config.get("tests", []) + self.__tests: Dict[str, Dict[str, dict]] = self.__config.get("tests", []) if isinstance(task, str): if task in ["ner", "text-classification", "question-answering"]: task = TaskManager(task) @@ -276,6 +276,9 @@ def __initialize_config_df(self) -> pd.DataFrame: ) df = pd.concat([df, pd.DataFrame(temp_data)], ignore_index=True) + # Convert 'proportion' column to float + df["proportion"] = pd.to_numeric(df["proportion"], errors="coerce") + # normalize the proportion and round it to 2 decimal places df["normalized_proportion"] = df["proportion"] / df["proportion"].sum() df["normalized_proportion"] = df["normalized_proportion"].apply( From b7f68c1e161f0da644f0f0b4343650490a82def5 Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Fri, 16 Aug 2024 14:33:06 +0530 Subject: [PATCH 3/3] Refactor DataAugmenter to improve proportion handling --- langtest/augmentation/augmenter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/langtest/augmentation/augmenter.py b/langtest/augmentation/augmenter.py index 7af5a30d4..f587adc27 100644 --- a/langtest/augmentation/augmenter.py +++ b/langtest/augmentation/augmenter.py @@ -66,6 +66,8 @@ def augment(self, data: Union[str, Iterable]) -> str: self.__datafactory = self.__datafactory(file_path=data, task=self.__task) data = self.__datafactory.load() + elif isinstance(self.__datafactory, DataFactory): + data = self.__datafactory.load() # generate the augmented data test_cases = self.__testfactory.transform(self.__task, data, self.__tests)