Merge pull request #1094 from JohnSnowLabs/enhance/document-wise-data-processing-in-conlldataset

implemented: basic version to handle document-wise processing.
chakravarthik27 authored Sep 11, 2024
2 parents 97bee72 + 3927b24 commit f1fbdc1
Showing 3 changed files with 132 additions and 66 deletions.
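This commit lets the CoNLL loader for the NER task build one sample per -DOCSTART- document instead of one sample per sentence, keeping sentence boundaries as "\n" pseudo-tokens. A minimal usage sketch follows; the Harness entry point and the "model"/"hub"/"data_source" keys are assumptions about the existing langtest API, and only the doc_wise flag is what this change reads from the data config:

# Hedged usage sketch: enable document-wise loading of a CoNLL NER dataset.
# The Harness import and the "model"/"hub"/"data_source" keys are assumptions
# about the existing langtest API; "doc_wise" is the flag this commit adds.
from langtest import Harness

harness = Harness(
    task="ner",
    model={"model": "en_core_web_sm", "hub": "spacy"},
    data={
        "data_source": "sample.conll",  # CoNLL file with -DOCSTART- markers
        "doc_wise": True,  # one sample per document instead of per sentence
    },
)
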
128 changes: 92 additions & 36 deletions langtest/datahandler/datasource.py
@@ -172,7 +172,7 @@ class DataFactory:
data_sources: Dict[str, BaseDataset] = BaseDataset.data_sources
CURATED_BIAS_DATASETS = ["BoolQ", "XSum"]

def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
"""Initializes DataFactory object.
Args:
@@ -232,6 +232,9 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
self.init_cls: BaseDataset = None
self.kwargs = kwargs

if self.task == "ner" and "doc_wise" in self._custom_label:
self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})

def load_raw(self):
"""Loads the data into a raw format"""
self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
@@ -257,7 +260,9 @@ def load(self) -> List[Sample]:
return DataFactory.load_curated_bias(self._file_path)
else:
self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
self._file_path, task=self.task, **self.kwargs
self._file_path,
task=self.task,
**self.kwargs,
)

loaded_data = self.init_cls.load_data()
@@ -425,7 +430,9 @@ class ConllDataset(BaseDataset):

COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}

def __init__(self, file_path: str, task: TaskManager) -> None:
def __init__(
self, file_path: Union[str, Dict[str, str]], task: TaskManager, **kwargs
) -> None:
"""Initializes ConllDataset object.
Args:
@@ -434,7 +441,7 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
"""
super().__init__()
self._file_path = file_path

self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
self.task = task

def load_raw_data(self) -> List[Dict]:
@@ -495,42 +502,42 @@ def load_data(self) -> List[NERSample]:
    ]
    for d_id, doc in enumerate(docs):
        if self.doc_wise:
            tokens = doc.strip().split("\n")
            ner_labels = []
            cursor = 0

            for token in tokens:
                token_list = token.split()

                if len(token_list) == 0:
                    pred = NERPrediction.from_span(
                        entity="",
                        word="\n",
                        start=cursor,
                        end=cursor,
                        pos_tag="",
                        chunk_tag="",
                    )
                    ner_labels.append(pred)
                else:
                    ner_labels.append(
                        NERPrediction.from_span(
                            entity=token_list[-1],
                            word=token_list[0],
                            start=cursor,
                            end=cursor + len(token_list[0]),
                            doc_id=d_id,
                            doc_name=(
                                docs_strings[d_id] if len(docs_strings) > 0 else ""
                            ),
                            pos_tag=token_list[1],
                            chunk_tag=token_list[2],
                        )
                    )
                    cursor += len(token_list[0]) + 1

            original = " ".join([label.span.word for label in ner_labels])

            data.append(
                self.task.get_sample_class(
                    original=original,
                    expected_results=NEROutput(predictions=ner_labels),
                )
            )

        else:
            sentences = re.split(r"\n\n|\n\s+\n", doc.strip())

            if sentences == [""]:
                continue

            for sent in sentences:
                # sentence string to token level split
                tokens = sent.strip().split("\n")

                # get annotations from token level split
                valid_tokens, token_list = self.__token_validation(tokens)

                if not valid_tokens:
                    logging.warning(Warnings.W004(sent=sent))
                    continue

                # get token and labels from the split
                ner_labels = []
                cursor = 0
                for split in token_list:
                    ner_labels.append(
                        NERPrediction.from_span(
                            entity=split[-1],
                            word=split[0],
                            start=cursor,
                            end=cursor + len(split[0]),
                            doc_id=d_id,
                            doc_name=(
                                docs_strings[d_id] if len(docs_strings) > 0 else ""
                            ),
                            pos_tag=split[1],
                            chunk_tag=split[2],
                        )
                    )
                    # +1 to account for the white space
                    cursor += len(split[0]) + 1

                original = " ".join([label.span.word for label in ner_labels])

                data.append(
                    self.task.get_sample_class(
                        original=original,
                        expected_results=NEROutput(predictions=ner_labels),
                    )
                )
self.dataset_size = len(data)
return data
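For reference, a self-contained sketch (not the library code) of the document-wise idea implemented above: each -DOCSTART- block is read as one token stream, and blank lines inside the block are kept as "\n" pseudo-tokens so sentence boundaries survive in the document-level sample. The -DOCSTART- pattern and column layout below are illustrative assumptions.

import re

conll = """-DOCSTART- -X- -X- O

John NNP B-NP B-PER
lives VBZ B-VP O

Mary NNP B-NP B-PER
works VBZ B-VP O
"""

# split the file into documents on the -DOCSTART- marker line
docs = [d.strip() for d in re.split(r"-DOCSTART- \S+ \S+ O", conll.strip()) if d.strip()]

for doc in docs:
    words = []
    for line in doc.split("\n"):
        cols = line.split()
        words.append(cols[0] if cols else "\n")  # empty line -> sentence boundary token
    print(words)  # ['John', 'lives', '\n', 'Mary', 'works']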

68 changes: 39 additions & 29 deletions langtest/datahandler/format.py
@@ -195,43 +195,53 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
test_case = sample.test_case
original = sample.original
if test_case:
    test_case_items = test_case.split(" ")
    norm_test_case_items = test_case.lower().split(" ")
    norm_original_items = original.lower().split(" ")
    temp_len = 0
    for jdx, item in enumerate(norm_test_case_items):
        if test_case_items[jdx] == "\n":
            text += "\n"  # add a newline character after each sentence
        else:
            try:
                if (
                    item in norm_original_items
                    and jdx >= norm_original_items.index(item)
                ):
                    oitem_index = norm_original_items.index(item)
                    j = sample.expected_results.predictions[
                        oitem_index + temp_len
                    ]
                    if temp_id != j.doc_id and jdx == 0:
                        text += f"{j.doc_name}\n\n"
                        temp_id = j.doc_id
                    text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
                    norm_original_items.pop(oitem_index)
                    temp_len += 1
                else:
                    o_item = sample.expected_results.predictions[jdx].span.word
                    letters_count = len(set(item) - set(o_item))
                    if (
                        len(norm_test_case_items)
                        == len(original.lower().split(" "))
                        or letters_count < 2
                    ):
                        tl = sample.expected_results.predictions[jdx]
                        text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
                    else:
                        text += f"{test_case_items[jdx]} -X- -X- O\n"
            except IndexError:
                text += f"{test_case_items[jdx]} -X- -X- O\n"

else:
    for j in sample.expected_results.predictions:
        if j.span.word == "\n":
            text += "\n"
        else:
            if temp_id != j.doc_id:
                text += f"{j.doc_name}\n\n"
                temp_id = j.doc_id
            text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"

return text, temp_id
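A tiny sketch of the matching export behaviour added above, using plain tuples instead of NERPrediction objects (illustrative only): whenever the stored word is the "\n" pseudo-token, a blank line is written so the exported CoNLL keeps its sentence boundaries inside a document.

# (word, pos_tag, chunk_tag, entity) stand-ins for NERPrediction objects
predictions = [
    ("John", "NNP", "B-NP", "B-PER"),
    ("lives", "VBZ", "B-VP", "O"),
    ("\n", "", "", ""),  # sentence boundary kept as a pseudo-token
    ("Mary", "NNP", "B-NP", "B-PER"),
]

text = ""
for word, pos, chunk, entity in predictions:
    if word == "\n":
        text += "\n"  # blank line between sentences
    else:
        text += f"{word} {pos} {chunk} {entity}\n"

print(text)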

2 changes: 1 addition & 1 deletion langtest/transform/robustness.py
@@ -389,7 +389,7 @@ def transform(sample_list: List[Sample], prob: Optional[float] = 1.0) -> List[Sa
]
sample_list[idx] = " ".join(transformed_words)
else:
words = sample.original.split()
words = sample.original.split(" ")
num_transform_words = int(prob * len(words))
transformed_indices = random.sample(
range(len(words)), num_transform_words
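A two-line illustration of why split() is replaced with split(" ") here: str.split() with no argument collapses all whitespace and would silently drop the "\n" pseudo-tokens that mark sentence boundaries in document-wise samples, while split(" ") preserves them.

original = "John lives \n Mary works"

print(original.split())     # ['John', 'lives', 'Mary', 'works'] -- boundary lost
print(original.split(" "))  # ['John', 'lives', '\n', 'Mary', 'works'] -- boundary kept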
