diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 0d23033e7..5dcb68ca3 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -8,6 +8,7 @@ on:
   pull_request:
     branches:
       - "release/*"
+      - "patch/*"
       - "main"
 
 jobs:
diff --git a/langtest/augmentation/base.py b/langtest/augmentation/base.py
index ca6718133..2767df6cb 100644
--- a/langtest/augmentation/base.py
+++ b/langtest/augmentation/base.py
@@ -19,7 +19,6 @@
 from langtest.utils.custom_types.predictions import NERPrediction, SequenceLabel
 from langtest.utils.custom_types.sample import NERSample
 from langtest.tasks import TaskManager
-from ..utils.lib_manager import try_import_lib
 from ..errors import Errors
 
 
@@ -358,6 +357,9 @@ def __init__(
             # Extend the existing templates list
             self.__templates.extend(generated_templates[:num_extra_templates])
 
+        except ModuleNotFoundError:
+            raise ImportError(Errors.E097())
+
         except Exception as e_msg:
             raise Errors.E095(e=e_msg)
 
@@ -606,19 +608,19 @@ def __generate_templates(
         num_extra_templates: int,
         model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
     ) -> List[str]:
-        if try_import_lib("openai"):
-            from langtest.augmentation.utils import (
-                generate_templates_azoi,  # azoi means Azure OpenAI
-                generate_templates_openai,
-            )
+        """This method is used to generate extra templates from a given template."""
+        from langtest.augmentation.utils import (
+            generate_templates_azoi,  # azoi means Azure OpenAI
+            generate_templates_openai,
+        )
 
-            params = model_config.copy() if model_config else {}
+        params = model_config.copy() if model_config else {}
 
-            if model_config and model_config.get("provider") == "openai":
-                return generate_templates_openai(template, num_extra_templates, params)
+        if model_config and model_config.get("provider") == "openai":
+            return generate_templates_openai(template, num_extra_templates, params)
 
-            elif model_config and model_config.get("provider") == "azure":
-                return generate_templates_azoi(template, num_extra_templates, params)
+        elif model_config and model_config.get("provider") == "azure":
+            return generate_templates_azoi(template, num_extra_templates, params)
 
-            else:
-                return generate_templates_openai(template, num_extra_templates)
+        else:
+            return generate_templates_openai(template, num_extra_templates)
diff --git a/langtest/augmentation/utils.py b/langtest/augmentation/utils.py
index a13a8d2e2..ad0051be5 100644
--- a/langtest/augmentation/utils.py
+++ b/langtest/augmentation/utils.py
@@ -19,15 +19,13 @@ class OpenAIConfig(TypedDict):
 class AzureOpenAIConfig(TypedDict):
     """Azure OpenAI Configuration for API Key and Provider."""
 
-    from openai.lib.azure import AzureADTokenProvider
-
     azure_endpoint: str
     api_version: str
     api_key: str
     provider: str
     azure_deployment: Union[str, None] = None
     azure_ad_token: Union[str, None] = (None,)
-    azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
+    azure_ad_token_provider = (None,)
     organization: Union[str, None] = (None,)
 
 
@@ -76,6 +74,7 @@ def generate_templates_azoi(
     template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
 ):
     """Generate new templates based on the provided template using Azure OpenAI API."""
+    import openai
 
     if "provider" in model_config:
@@ -139,6 +138,7 @@ def generate_templates_openai(
     template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
 ):
     """Generate new templates based on the provided template using OpenAI API."""
+    import openai
 
     if "provider" in model_config:
diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 5e35fc97b..4de9999f4 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -172,7 +172,7 @@ class DataFactory:
     data_sources: Dict[str, BaseDataset] = BaseDataset.data_sources
     CURATED_BIAS_DATASETS = ["BoolQ", "XSum"]
 
-    def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
+    def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
         """Initializes DataFactory object.
 
         Args:
@@ -232,6 +232,9 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
         self.init_cls: BaseDataset = None
         self.kwargs = kwargs
 
+        if self.task == "ner" and "doc_wise" in self._custom_label:
+            self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
+
     def load_raw(self):
         """Loads the data into a raw format"""
         self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
@@ -257,7 +260,9 @@ def load(self) -> List[Sample]:
             return DataFactory.load_curated_bias(self._file_path)
         else:
             self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
-                self._file_path, task=self.task, **self.kwargs
+                self._file_path,
+                task=self.task,
+                **self.kwargs,
             )
         loaded_data = self.init_cls.load_data()
 
@@ -425,7 +430,9 @@ class ConllDataset(BaseDataset):
 
     COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}
 
-    def __init__(self, file_path: str, task: TaskManager) -> None:
+    def __init__(
+        self, file_path: Union[str, Dict[str, str]], task: TaskManager, **kwargs
+    ) -> None:
         """Initializes ConllDataset object.
 
         Args:
@@ -434,7 +441,7 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
         """
         super().__init__()
         self._file_path = file_path
-
+        self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
         self.task = task
 
     def load_raw_data(self) -> List[Dict]:
@@ -495,42 +502,42 @@ def load_data(self) -> List[NERSample]:
             ]
             for d_id, doc in enumerate(docs):
                 # file content to sentence split
-                sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
-
-                if sentences == [""]:
-                    continue
-
-                for sent in sentences:
-                    # sentence string to token level split
-                    tokens = sent.strip().split("\n")
-
-                    # get annotations from token level split
-                    valid_tokens, token_list = self.__token_validation(tokens)
-
-                    if not valid_tokens:
-                        logging.warning(Warnings.W004(sent=sent))
-                        continue
-
-                    # get token and labels from the split
+                if self.doc_wise:
+                    tokens = doc.strip().split("\n")
                     ner_labels = []
                     cursor = 0
-                    for split in token_list:
-                        ner_labels.append(
-                            NERPrediction.from_span(
-                                entity=split[-1],
-                                word=split[0],
+
+                    for token in tokens:
+                        token_list = token.split()
+
+                        if len(token_list) == 0:
+                            pred = NERPrediction.from_span(
+                                entity="",
+                                word="\n",
                                 start=cursor,
-                                end=cursor + len(split[0]),
-                                doc_id=d_id,
-                                doc_name=(
-                                    docs_strings[d_id] if len(docs_strings) > 0 else ""
-                                ),
-                                pos_tag=split[1],
-                                chunk_tag=split[2],
+                                end=cursor,
+                                pos_tag="",
+                                chunk_tag="",
                             )
-                        )
-                        # +1 to account for the white space
-                        cursor += len(split[0]) + 1
+                            ner_labels.append(pred)
+                        else:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=token_list[-1],
+                                    word=token_list[0],
+                                    start=cursor,
+                                    end=cursor + len(token_list[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=token_list[1],
+                                    chunk_tag=token_list[2],
+                                )
+                            )
+                            cursor += len(token_list[0]) + 1
 
                     original = " ".join([label.span.word for label in ner_labels])
 
@@ -540,6 +547,55 @@
                             expected_results=NEROutput(predictions=ner_labels),
                         )
                     )
+
+                else:
+                    sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
+
+                    if sentences == [""]:
+                        continue
+
+                    for sent in sentences:
+                        # sentence string to token level split
+                        tokens = sent.strip().split("\n")
+
+                        # get annotations from token level split
+                        valid_tokens, token_list = self.__token_validation(tokens)
+
+                        if not valid_tokens:
+                            logging.warning(Warnings.W004(sent=sent))
+                            continue
+
+                        # get token and labels from the split
+                        ner_labels = []
+                        cursor = 0
+                        for split in token_list:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=split[-1],
+                                    word=split[0],
+                                    start=cursor,
+                                    end=cursor + len(split[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=split[1],
+                                    chunk_tag=split[2],
+                                )
+                            )
+                            # +1 to account for the white space
+                            cursor += len(split[0]) + 1
+
+                        original = " ".join([label.span.word for label in ner_labels])
+
+                        data.append(
+                            self.task.get_sample_class(
+                                original=original,
+                                expected_results=NEROutput(predictions=ner_labels),
+                            )
+                        )
 
         self.dataset_size = len(data)
         return data
diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index 808c0ade2..e6ba8a459 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -195,43 +195,53 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
         test_case = sample.test_case
         original = sample.original
         if test_case:
-            test_case_items = test_case.split()
-            norm_test_case_items = test_case.lower().split()
-            norm_original_items = original.lower().split()
+            test_case_items = test_case.split(" ")
+            norm_test_case_items = test_case.lower().split(" ")
+            norm_original_items = original.lower().split(" ")
             temp_len = 0
             for jdx, item in enumerate(norm_test_case_items):
-                try:
-                    if item in norm_original_items and jdx >= norm_original_items.index(
-                        item
-                    ):
-                        oitem_index = norm_original_items.index(item)
-                        j = sample.expected_results.predictions[oitem_index + temp_len]
-                        if temp_id != j.doc_id and jdx == 0:
-                            text += f"{j.doc_name}\n\n"
-                            temp_id = j.doc_id
-                        text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
-                        norm_original_items.pop(oitem_index)
-                        temp_len += 1
-                    else:
-                        o_item = sample.expected_results.predictions[jdx].span.word
-                        letters_count = len(set(item) - set(o_item))
+                if test_case_items[jdx] == "\n":
+                    text += "\n"  # add a newline character after each sentence
+                else:
+                    try:
                         if (
-                            len(norm_test_case_items) == len(original.lower().split())
-                            or letters_count < 2
+                            item in norm_original_items
+                            and jdx >= norm_original_items.index(item)
                         ):
-                            tl = sample.expected_results.predictions[jdx]
-                            text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
+                            oitem_index = norm_original_items.index(item)
+                            j = sample.expected_results.predictions[
+                                oitem_index + temp_len
+                            ]
+                            if temp_id != j.doc_id and jdx == 0:
+                                text += f"{j.doc_name}\n\n"
+                                temp_id = j.doc_id
+                            text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                            norm_original_items.pop(oitem_index)
+                            temp_len += 1
                         else:
-                            text += f"{test_case_items[jdx]} -X- -X- O\n"
-                except IndexError:
-                    text += f"{test_case_items[jdx]} -X- -X- O\n"
+                            o_item = sample.expected_results.predictions[jdx].span.word
+                            letters_count = len(set(item) - set(o_item))
+                            if (
+                                len(norm_test_case_items)
+                                == len(original.lower().split(" "))
+                                or letters_count < 2
+                            ):
+                                tl = sample.expected_results.predictions[jdx]
+                                text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
+                            else:
+                                text += f"{test_case_items[jdx]} -X- -X- O\n"
+                    except IndexError:
+                        text += f"{test_case_items[jdx]} -X- -X- O\n"
         else:
             for j in sample.expected_results.predictions:
-                if temp_id != j.doc_id:
-                    text += f"{j.doc_name}\n\n"
-                    temp_id = j.doc_id
-                text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                if j.span.word == "\n":
+                    text += "\n"
+                else:
+                    if temp_id != j.doc_id:
+                        text += f"{j.doc_name}\n\n"
+                        temp_id = j.doc_id
+                    text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
 
         return text, temp_id
diff --git a/langtest/errors.py b/langtest/errors.py
index d3d7d1bba..c4cb90189 100644
--- a/langtest/errors.py
+++ b/langtest/errors.py
@@ -275,6 +275,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
     E095 = ("Failed to make API request: {e}")
     E096 = ("Failed to generate the templates in Augmentation: {msg}")
+    E097 = ("Failed to load openai. Please install it using `pip install openai`")
 
 
 class ColumnNameError(Exception):
diff --git a/langtest/transform/robustness.py b/langtest/transform/robustness.py
index 5172c5135..ac3ed4fd7 100644
--- a/langtest/transform/robustness.py
+++ b/langtest/transform/robustness.py
@@ -378,7 +378,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
         """
         for idx, sample in enumerate(sample_list):
             if isinstance(sample, str):
-                words = sample.split()
+                words = sample.split(" ")
                 num_transform_words = int(prob * len(words))
                 transformed_indices = random.sample(
                     range(len(words)), num_transform_words
@@ -389,7 +389,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
                 ]
                 sample_list[idx] = " ".join(transformed_words)
             else:
-                words = sample.original.split()
+                words = sample.original.split(" ")
                 num_transform_words = int(prob * len(words))
                 transformed_indices = random.sample(
                     range(len(words)), num_transform_words
@@ -422,7 +422,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
         """
         for idx, sample in enumerate(sample_list):
            if isinstance(sample, str):
-                words = sample.split()
+                words = sample.split(" ")
                 num_transform_words = int(prob * len(words))
                 transformed_indices = random.sample(
                     range(len(words)), num_transform_words
@@ -433,7 +433,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
                 ]
                 sample_list[idx] = " ".join(transformed_words)
             else:
-                words = sample.original.split()
+                words = sample.original.split(" ")
                 num_transform_words = int(prob * len(words))
                 transformed_indices = random.sample(
                     range(len(words)), num_transform_words
@@ -466,7 +466,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
         """
         for idx, sample in enumerate(sample_list):
             if isinstance(sample, str):
-                words = sample.split()
+                words = sample.split(" ")
                 num_transform_words = int(prob * len(words))
                 transformed_indices = random.sample(
                     range(len(words)), num_transform_words
@@ -477,7 +477,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
                 ]
                 sample_list[idx] = " ".join(transformed_words)
             else:
-                words = sample.original.split()
+                words = sample.original.split(" ")
                 num_transform_words = int(prob * len(words))
                 transformed_indices = random.sample(
                     range(len(words)), num_transform_words
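
Note on the new `doc_wise` option (a minimal usage sketch, not part of the patch): `DataFactory.__init__` now forwards `doc_wise` to `ConllDataset` when the task is "ner" and the key is present in the data dict, so document-wise CoNLL loading should be reachable from the usual `Harness` entry point. The file path below is hypothetical, and the exact `Harness` wiring is an assumption inferred from this diff:

    from langtest import Harness

    # Hypothetical invocation: DataFactory only forwards `doc_wise` when
    # task == "ner" and the key appears in the data dict.
    harness = Harness(
        task="ner",
        model={"model": "en_core_web_sm", "hub": "spacy"},
        data={
            "data_source": "sample.conll",  # hypothetical path
            "doc_wise": True,  # keep each document intact instead of sentence-splitting
        },
    )

With `doc_wise` unset, `ConllDataset.load_data()` falls back to the original sentence-splitting branch, so existing configurations behave as before.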
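Note on the deferred `openai` import (a sketch under assumptions): `generate_templates_azoi` and `generate_templates_openai` now import `openai` inside the function body instead of gating on `try_import_lib`, and the `except ModuleNotFoundError` added in base.py re-raises the failure as `ImportError(Errors.E097())` during template generation. The template string below is hypothetical; calling a helper directly without the package installed should trip the deferred import:

    from langtest.augmentation.utils import generate_templates_openai

    try:
        # Raises ModuleNotFoundError here if `openai` is not installed; the
        # augmentation flow in base.py converts the same failure into
        # ImportError(Errors.E097()).
        generate_templates_openai("The {entity} visited {location}.", 2)
    except ModuleNotFoundError:
        print("Failed to load openai. Please install it using `pip install openai`")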