From 688723304d37db50c48fe8de4b07226799189166 Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Mon, 2 Sep 2024 21:42:20 +0530
Subject: [PATCH 1/8] implemented: basic version to handle document-wise loading.

---
 langtest/datahandler/datasource.py | 100 +++++++++++++++++++++--------
 1 file changed, 72 insertions(+), 28 deletions(-)

diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 51071334f..8d1f3188d 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -171,7 +171,7 @@ class DataFactory:
     data_sources: Dict[str, BaseDataset] = BaseDataset.data_sources
     CURATED_BIAS_DATASETS = ["BoolQ", "XSum"]
 
-    def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
+    def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
         """Initializes DataFactory object.
 
         Args:
@@ -230,6 +230,7 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
         self.task = task
         self.init_cls: BaseDataset = None
         self.kwargs = kwargs
+        self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
 
     def load_raw(self):
         """Loads the data into a raw format"""
@@ -256,7 +257,9 @@ def load(self) -> List[Sample]:
                 return DataFactory.load_curated_bias(self._file_path)
             else:
                 self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
-                    self._file_path, task=self.task, **self.kwargs
+                    self._file_path,
+                    task=self.task,
+                    **self.kwargs,
                 )

         loaded_data = self.init_cls.load_data()
@@ -424,7 +427,9 @@ class ConllDataset(BaseDataset):
 
     COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}
 
-    def __init__(self, file_path: str, task: TaskManager) -> None:
+    def __init__(
+        self, file_path: Union[str, Dict[str, str]], task: TaskManager, **kwargs
+    ) -> None:
         """Initializes ConllDataset object.

         Args:
@@ -432,8 +437,9 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
             task (str): name of the task to perform
         """
         super().__init__()
+        print(kwargs, file_path)
         self._file_path = file_path
-
+        self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
         self.task = task
 
     def load_raw_data(self) -> List[Dict]:
@@ -494,42 +500,31 @@ def load_data(self) -> List[NERSample]:
             ]
             for d_id, doc in enumerate(docs):
                 # file content to sentence split
-                sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
-
-                if sentences == [""]:
-                    continue
-
-                for sent in sentences:
-                    # sentence string to token level split
-                    tokens = sent.strip().split("\n")
+                if self.doc_wise:
+                    tokens = doc.strip().split("\n")
+                    ner_labels = []
+                    cursor = 0
 
-                    # get annotations from token level split
-                    valid_tokens, token_list = self.__token_validation(tokens)
+                    for token in tokens:
+                        token_list = token.split()
 
-                    if not valid_tokens:
-                        logging.warning(Warnings.W004(sent=sent))
-                        continue
+                        print(token, token_list)
+                        if len(token_list) == 0:
+                            continue
 
-                    # get token and labels from the split
-                    ner_labels = []
-                    cursor = 0
-                    for split in token_list:
                         ner_labels.append(
                             NERPrediction.from_span(
-                                entity=split[-1],
-                                word=split[0],
+                                entity=token_list[-1],
+                                word=token_list[0],
                                 start=cursor,
-                                end=cursor + len(split[0]),
+                                end=cursor + len(token_list[0]),
                                 doc_id=d_id,
                                 doc_name=(
                                     docs_strings[d_id] if len(docs_strings) > 0 else ""
                                 ),
-                                pos_tag=split[1],
-                                chunk_tag=split[2],
                             )
                         )
-                        # +1 to account for the white space
-                        cursor += len(split[0]) + 1
+                        cursor += len(token_list[0]) + 1
 
                 original = " ".join([label.span.word for label in ner_labels])
@@ -539,6 +534,55 @@ def load_data(self) -> List[NERSample]:
                             expected_results=NEROutput(predictions=ner_labels),
                         )
                     )
+
+                else:
+                    sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
+
+                    if sentences == [""]:
+                        continue
+
+                    for sent in sentences:
+                        # sentence string to token level split
+                        tokens = sent.strip().split("\n")
+
+                        # get annotations from token level split
+                        valid_tokens, token_list = self.__token_validation(tokens)
+
+                        if not valid_tokens:
+                            logging.warning(Warnings.W004(sent=sent))
+                            continue
+
+                        # get token and labels from the split
+                        ner_labels = []
+                        cursor = 0
+                        for split in token_list:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=split[-1],
+                                    word=split[0],
+                                    start=cursor,
+                                    end=cursor + len(split[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=split[1],
+                                    chunk_tag=split[2],
+                                )
+                            )
+                            # +1 to account for the white space
+                            cursor += len(split[0]) + 1
+
+                        original = " ".join([label.span.word for label in ner_labels])
+
+                        data.append(
+                            self.task.get_sample_class(
+                                original=original,
+                                expected_results=NEROutput(predictions=ner_labels),
+                            )
+                        )
 
         self.dataset_size = len(data)
         return data

From beec9c3d5d8ebfa2c9742d3c74e73fff5f317811 Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Tue, 3 Sep 2024 13:35:25 +0530
Subject: [PATCH 2/8] feat: Add pos_tag and chunk_tag to ConllDataset token creation in doc_wise.

---
 langtest/datahandler/datasource.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 8d1f3188d..702fa2062 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -522,6 +522,8 @@ def load_data(self) -> List[NERSample]:
                                 doc_name=(
                                     docs_strings[d_id] if len(docs_strings) > 0 else ""
                                 ),
+                                pos_tag=token_list[1],
+                                chunk_tag=token_list[2],
                             )
                         )
                         cursor += len(token_list[0]) + 1

From 2eb72a3cdb15b29ba596c7163c8a535de47a322b Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Tue, 10 Sep 2024 14:14:59 +0530
Subject: [PATCH 3/8] Refactor ConllDataset token creation in doc_wise

---
 langtest/datahandler/datasource.py | 40 +++++++++++++--------
 langtest/datahandler/format.py     | 57 +++++++++++++++++-------------
 langtest/transform/robustness.py   |  2 +-
 3 files changed, 58 insertions(+), 41 deletions(-)

diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 702fa2062..8be470b52 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -510,23 +510,33 @@ def load_data(self) -> List[NERSample]:
 
                         print(token, token_list)
                         if len(token_list) == 0:
-                            continue
-
-                        ner_labels.append(
-                            NERPrediction.from_span(
-                                entity=token_list[-1],
-                                word=token_list[0],
+                            pred = NERPrediction.from_span(
+                                entity="",
+                                word="\n",
                                 start=cursor,
-                                end=cursor + len(token_list[0]),
-                                doc_id=d_id,
-                                doc_name=(
-                                    docs_strings[d_id] if len(docs_strings) > 0 else ""
-                                ),
-                                pos_tag=token_list[1],
-                                chunk_tag=token_list[2],
+                                end=cursor,
+                                pos_tag="",
+                                chunk_tag="",
                             )
-                        )
-                        cursor += len(token_list[0]) + 1
+                            ner_labels.append(pred)
+                        else:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=token_list[-1],
+                                    word=token_list[0],
+                                    start=cursor,
+                                    end=cursor + len(token_list[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=token_list[1],
+                                    chunk_tag=token_list[2],
+                                )
+                            )
+                            cursor += len(token_list[0]) + 1
 
                 original = " ".join([label.span.word for label in ner_labels])
diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index 0755108f0..f62ca71fc 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -186,36 +186,43 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
         test_case = sample.test_case
         original = sample.original
         if test_case:
-            test_case_items = test_case.split()
-            norm_test_case_items = test_case.lower().split()
-            norm_original_items = original.lower().split()
+            test_case_items = test_case.split(" ")
+            norm_test_case_items = test_case.lower().split(" ")
+            norm_original_items = original.lower().split(" ")
             temp_len = 0
             for jdx, item in enumerate(norm_test_case_items):
-                try:
-                    if item in norm_original_items and jdx >= norm_original_items.index(
-                        item
-                    ):
-                        oitem_index = norm_original_items.index(item)
-                        j = sample.expected_results.predictions[oitem_index + temp_len]
-                        if temp_id != j.doc_id and jdx == 0:
-                            text += f"{j.doc_name}\n\n"
-                            temp_id = j.doc_id
-                        text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
-                        norm_original_items.pop(oitem_index)
-                        temp_len += 1
-                    else:
-                        o_item = sample.expected_results.predictions[jdx].span.word
-                        letters_count = len(set(item) - set(o_item))
+                if test_case_items[jdx] == "\n":
+                    text += "\n"
+                else:
+                    try:
                         if (
-                            len(norm_test_case_items) == len(original.lower().split())
-                            or letters_count < 2
+                            item in norm_original_items
+                            and jdx >= norm_original_items.index(item)
                         ):
-                            tl = sample.expected_results.predictions[jdx]
-                            text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
+                            oitem_index = norm_original_items.index(item)
+                            j = sample.expected_results.predictions[
+                                oitem_index + temp_len
+                            ]
+                            if temp_id != j.doc_id and jdx == 0:
+                                text += f"{j.doc_name}\n\n"
+                                temp_id = j.doc_id
+                            text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                            norm_original_items.pop(oitem_index)
+                            temp_len += 1
                         else:
-                            text += f"{test_case_items[jdx]} -X- -X- O\n"
-                except IndexError:
-                    text += f"{test_case_items[jdx]} -X- -X- O\n"
+                            o_item = sample.expected_results.predictions[jdx].span.word
+                            letters_count = len(set(item) - set(o_item))
+                            if (
+                                len(norm_test_case_items)
+                                == len(original.lower().split(" "))
+                                or letters_count < 2
+                            ):
+                                tl = sample.expected_results.predictions[jdx]
+                                text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
+                            else:
+                                text += f"{test_case_items[jdx]} -X- -X- O\n"
+                    except IndexError:
+                        text += f"{test_case_items[jdx]} -X- -X- O\n"
 
         else:
             for j in sample.expected_results.predictions:
diff --git a/langtest/transform/robustness.py b/langtest/transform/robustness.py
index 5172c5135..d0ae9fc19 100644
--- a/langtest/transform/robustness.py
+++ b/langtest/transform/robustness.py
@@ -389,7 +389,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
             ]
             sample_list[idx] = " ".join(transformed_words)
         else:
-            words = sample.original.split()
+            words = sample.original.split(" ")
             num_transform_words = int(prob * len(words))
             transformed_indices = random.sample(
                 range(len(words)), num_transform_words

From aef135256296b12dd2bfcc2aa6268546e165f4e4 Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Tue, 10 Sep 2024 14:44:09 +0530
Subject: [PATCH 4/8] Refactor NEROutputFormatter to add newline character after each sentence

---
 langtest/datahandler/format.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index f62ca71fc..aea8086be 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -192,7 +192,7 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
             temp_len = 0
             for jdx, item in enumerate(norm_test_case_items):
                 if test_case_items[jdx] == "\n":
-                    text += "\n"
+                    text += "\n" # add a newline character after each sentence
                 else:
                     try:
                         if (

From 6a8aae387c9804451c2e0e63788bdd3cce5a9f7b Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Tue, 10 Sep 2024 15:13:47 +0530
Subject: [PATCH 5/8] fixed: linting issues

---
 langtest/datahandler/format.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index aea8086be..3d881df8e 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -192,7 +192,7 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
             temp_len = 0
             for jdx, item in enumerate(norm_test_case_items):
                 if test_case_items[jdx] == "\n":
-                    text += "\n" # add a newline character after each sentence
+                    text += "\n"  # add a newline character after each sentence
                 else:
                     try:
                         if (

From 3111017084f5133da2eededf0860c60b46c1462c Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Tue, 10 Sep 2024 15:45:23 +0530
Subject: [PATCH 6/8] Refactor NEROutputFormatter to handle newline characters in sample predictions

---
 langtest/datahandler/format.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index 3d881df8e..43f85f934 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -226,10 +226,13 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
 
         else:
             for j in sample.expected_results.predictions:
-                if temp_id != j.doc_id:
-                    text += f"{j.doc_name}\n\n"
-                    temp_id = j.doc_id
-                text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                if j.span.word == "\n":
+                    text += "\n"
+                else:
+                    if temp_id != j.doc_id:
+                        text += f"{j.doc_name}\n\n"
+                        temp_id = j.doc_id
+                    text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
 
         return text, temp_id

From 675941788aa410bd62018ae795f22f6d28ef135a Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Tue, 10 Sep 2024 16:16:10 +0530
Subject: [PATCH 7/8] fixed: issue with `doc_wise` parameter affecting other tasks.

---
 langtest/datahandler/datasource.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 8be470b52..06d6376be 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -230,7 +230,9 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
         self.task = task
         self.init_cls: BaseDataset = None
         self.kwargs = kwargs
-        self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
+
+        if self.task == "ner":
+            self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
 
     def load_raw(self):
         """Loads the data into a raw format"""
@@ -437,7 +439,6 @@ def __init__(
             task (str): name of the task to perform
         """
         super().__init__()
-        print(kwargs, file_path)
         self._file_path = file_path
         self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
         self.task = task
@@ -508,7 +509,6 @@ def load_data(self) -> List[NERSample]:
                     for token in tokens:
                         token_list = token.split()
 
-                        print(token, token_list)
                         if len(token_list) == 0:
                             pred = NERPrediction.from_span(
                                 entity="",

From 84040b60e5a2d6d9b866914e25736906b4a48540 Mon Sep 17 00:00:00 2001
From: Kalyan Chakravarthy
Date: Wed, 11 Sep 2024 10:33:30 +0530
Subject: [PATCH 8/8] fixed: doc_wise issue in harness import_testcases method.

---
 langtest/datahandler/datasource.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 06d6376be..71424a6f5 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -231,7 +231,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
         self.init_cls: BaseDataset = None
         self.kwargs = kwargs
 
-        if self.task == "ner":
+        if self.task == "ner" and "doc_wise" in self._custom_label:
             self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
 
     def load_raw(self):
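
Reviewer note on usage: taken together, the series threads a `doc_wise` flag
from the custom-input dict through DataFactory into ConllDataset, so a CoNLL
file is loaded as one NERSample per document (per -DOCSTART- block) rather
than one per sentence. A minimal sketch of how the flag would be passed,
based only on the `__init__` and `_custom_label` code in the diffs above;
the TaskManager import path and the file name are assumptions, not something
these patches confirm:

    # Sketch only: exercises the doc_wise flag introduced in this series.
    # Assumed: the TaskManager("ner") construction, its import path, and
    # "sample.conll" as a placeholder file name.
    from langtest.datahandler.datasource import DataFactory
    from langtest.tasks import TaskManager

    task = TaskManager("ner")
    factory = DataFactory(
        file_path={
            "data_source": "sample.conll",  # placeholder CoNLL file
            "doc_wise": True,  # new flag: one sample per document
        },
        task=task,
    )
    samples = factory.load()  # doc_wise=True -> one NERSample per document

Per patches 7 and 8, the flag is only honored for the "ner" task and only
when "doc_wise" is actually present in the input dict, so existing
configurations load exactly as before.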
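
Reviewer note on the newline convention: patches 3 and 6 introduce a sentinel
prediction, a zero-width token whose word is "\n" with empty tags, to mark
sentence boundaries inside a document, and NEROutputFormatter.to_conll writes
it back out as a blank line. A small sketch of the convention, assuming
NERPrediction.from_span and NEROutput keep the signatures visible in the
diffs and that langtest.utils.custom_types is the right import path (an
assumption):

    # Sketch only: a two-token sentence followed by the blank-line sentinel
    # that ConllDataset.load_data emits in doc_wise mode.
    from langtest.utils.custom_types import NEROutput, NERPrediction  # assumed path

    preds, cursor = [], 0
    for word, pos, chunk, entity in [
        ("EU", "NNP", "B-NP", "B-ORG"),  # illustrative CoNLL-style rows
        ("rejects", "VBZ", "B-VP", "O"),
    ]:
        preds.append(
            NERPrediction.from_span(
                entity=entity, word=word, start=cursor, end=cursor + len(word),
                pos_tag=pos, chunk_tag=chunk,
            )
        )
        cursor += len(word) + 1  # +1 for the joining space

    # Sentence boundary: zero-width "\n" token with empty tags.
    preds.append(
        NERPrediction.from_span(
            entity="", word="\n", start=cursor, end=cursor, pos_tag="", chunk_tag=""
        )
    )
    output = NEROutput(predictions=preds)
    # to_conll later renders the sentinel as a blank line, restoring the
    # original sentence breaks in the exported file.

One design wrinkle worth flagging: because the sentinel's word is "\n",
joining label.span.word with spaces (as load_data does for `original`)
embeds literal newlines in the sample text; that is what the split(" ")
changes in patch 3 accommodate, since str.split() with no argument would
silently drop those "\n" items.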