Merge pull request #1103 from JohnSnowLabs/patch/2.3.1

Patch/2.3.1

chakravarthik27 authored Sep 11, 2024
2 parents 78cb31f + 134de82 commit b35c28a
Showing 7 changed files with 157 additions and 87 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
@@ -8,6 +8,7 @@ on:
   pull_request:
     branches:
       - "release/*"
+      - "patch/*"
       - "main"

 jobs:
28 changes: 15 additions & 13 deletions langtest/augmentation/base.py
@@ -19,7 +19,6 @@
 from langtest.utils.custom_types.predictions import NERPrediction, SequenceLabel
 from langtest.utils.custom_types.sample import NERSample
 from langtest.tasks import TaskManager
-from ..utils.lib_manager import try_import_lib
 from ..errors import Errors


@@ -358,6 +357,9 @@ def __init__(
             # Extend the existing templates list

             self.__templates.extend(generated_templates[:num_extra_templates])
+        except ModuleNotFoundError:
+            raise ImportError(Errors.E097())
+
         except Exception as e_msg:
             raise Errors.E095(e=e_msg)

@@ -606,19 +608,19 @@ def __generate_templates(
         num_extra_templates: int,
         model_config: Union[OpenAIConfig, AzureOpenAIConfig] = None,
     ) -> List[str]:
-        if try_import_lib("openai"):
-            from langtest.augmentation.utils import (
-                generate_templates_azoi,  # azoi means Azure OpenAI
-                generate_templates_openai,
-            )
+        """This method is used to generate extra templates from a given template."""
+        from langtest.augmentation.utils import (
+            generate_templates_azoi,  # azoi means Azure OpenAI
+            generate_templates_openai,
+        )

-            params = model_config.copy() if model_config else {}
+        params = model_config.copy() if model_config else {}

-            if model_config and model_config.get("provider") == "openai":
-                return generate_templates_openai(template, num_extra_templates, params)
+        if model_config and model_config.get("provider") == "openai":
+            return generate_templates_openai(template, num_extra_templates, params)

-            elif model_config and model_config.get("provider") == "azure":
-                return generate_templates_azoi(template, num_extra_templates, params)
+        elif model_config and model_config.get("provider") == "azure":
+            return generate_templates_azoi(template, num_extra_templates, params)

-            else:
-                return generate_templates_openai(template, num_extra_templates)
+        else:
+            return generate_templates_openai(template, num_extra_templates)
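With `try_import_lib("openai")` gone, a missing `openai` package now propagates as a `ModuleNotFoundError` from the lazy import and is converted to an `ImportError` carrying E097 (see the `__init__` hunk above). A self-contained sketch of that pattern, with names mirroring the diff but the helper body itself illustrative only:

from typing import List


def generate_extra_templates(template: str, num_extra_templates: int) -> List[str]:
    try:
        import openai  # noqa: F401 -- optional dependency, imported lazily
        # ... provider-specific template generation would happen here ...
        return []
    except ModuleNotFoundError:
        # mirrors Errors.E097 from this patch
        raise ImportError(
            "Failed to load openai. Please install it using `pip install openai`"
        )
    except Exception as e:
        # mirrors Errors.E095 ("Failed to make API request: {e}")
        raise RuntimeError(f"Failed to make API request: {e}")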
6 changes: 3 additions & 3 deletions langtest/augmentation/utils.py
@@ -19,15 +19,13 @@ class OpenAIConfig(TypedDict):
 class AzureOpenAIConfig(TypedDict):
     """Azure OpenAI Configuration for API Key and Provider."""

-    from openai.lib.azure import AzureADTokenProvider
-
     azure_endpoint: str
     api_version: str
     api_key: str
     provider: str
     azure_deployment: Union[str, None] = None
     azure_ad_token: Union[str, None] = (None,)
-    azure_ad_token_provider: Union[AzureADTokenProvider, None] = (None,)
+    azure_ad_token_provider = (None,)
     organization: Union[str, None] = (None,)

@@ -76,6 +74,7 @@ def generate_templates_azoi(
     template: str, num_extra_templates: int, model_config: AzureOpenAIConfig
 ):
     """Generate new templates based on the provided template using Azure OpenAI API."""
+
     import openai

     if "provider" in model_config:
@@ -139,6 +138,7 @@ def generate_templates_openai(
     template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
 ):
     """Generate new templates based on the provided template using OpenAI API."""
+
     import openai

     if "provider" in model_config:
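Since `AzureOpenAIConfig` is a TypedDict, callers pass a plain dict; nothing is validated at runtime, the keys simply mirror the fields above. A hedged usage sketch with placeholder values:

from langtest.augmentation.utils import AzureOpenAIConfig

# All values below are placeholders, not real credentials or resources.
azure_config: AzureOpenAIConfig = {
    "azure_endpoint": "https://example-resource.openai.azure.com/",
    "api_version": "2024-02-01",  # assumed API version string
    "api_key": "<your-azure-openai-key>",
    "provider": "azure",
    "azure_deployment": "gpt-4o-mini",  # hypothetical deployment name
}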
128 changes: 92 additions & 36 deletions langtest/datahandler/datasource.py
@@ -172,7 +172,7 @@ class DataFactory:
     data_sources: Dict[str, BaseDataset] = BaseDataset.data_sources
     CURATED_BIAS_DATASETS = ["BoolQ", "XSum"]

-    def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
+    def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
         """Initializes DataFactory object.

         Args:
@@ -232,6 +232,9 @@ def __init__(self, file_path: dict, task: TaskManager, **kwargs) -> None:
         self.init_cls: BaseDataset = None
         self.kwargs = kwargs

+        if self.task == "ner" and "doc_wise" in self._custom_label:
+            self.kwargs.update({"doc_wise": self._custom_label.get("doc_wise", False)})
+
     def load_raw(self):
         """Loads the data into a raw format"""
         self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
@@ -257,7 +260,9 @@ def load(self) -> List[Sample]:
             return DataFactory.load_curated_bias(self._file_path)
         else:
             self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
-                self._file_path, task=self.task, **self.kwargs
+                self._file_path,
+                task=self.task,
+                **self.kwargs,
             )

         loaded_data = self.init_cls.load_data()
@@ -425,7 +430,9 @@ class ConllDataset(BaseDataset):

     COLUMN_NAMES = {task: COLUMN_MAPPER[task] for task in supported_tasks}

-    def __init__(self, file_path: str, task: TaskManager) -> None:
+    def __init__(
+        self, file_path: Union[str, Dict[str, str]], task: TaskManager, **kwargs
+    ) -> None:
         """Initializes ConllDataset object.

         Args:
Expand All @@ -434,7 +441,7 @@ def __init__(self, file_path: str, task: TaskManager) -> None:
"""
super().__init__()
self._file_path = file_path

self.doc_wise = kwargs.get("doc_wise") if "doc_wise" in kwargs else False
self.task = task

def load_raw_data(self) -> List[Dict]:
@@ -495,42 +502,42 @@ def load_data(self) -> List[NERSample]:
             ]
             for d_id, doc in enumerate(docs):
-                # file content to sentence split
-                sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
-
-                if sentences == [""]:
-                    continue
-
-                for sent in sentences:
-                    # sentence string to token level split
-                    tokens = sent.strip().split("\n")
-
-                    # get annotations from token level split
-                    valid_tokens, token_list = self.__token_validation(tokens)
-
-                    if not valid_tokens:
-                        logging.warning(Warnings.W004(sent=sent))
-                        continue
-
-                    # get token and labels from the split
+                if self.doc_wise:
+                    tokens = doc.strip().split("\n")
                     ner_labels = []
                     cursor = 0
-                    for split in token_list:
-                        ner_labels.append(
-                            NERPrediction.from_span(
-                                entity=split[-1],
-                                word=split[0],
+
+                    for token in tokens:
+                        token_list = token.split()
+
+                        if len(token_list) == 0:
+                            pred = NERPrediction.from_span(
+                                entity="",
+                                word="\n",
                                 start=cursor,
-                                end=cursor + len(split[0]),
-                                doc_id=d_id,
-                                doc_name=(
-                                    docs_strings[d_id] if len(docs_strings) > 0 else ""
-                                ),
-                                pos_tag=split[1],
-                                chunk_tag=split[2],
+                                end=cursor,
+                                pos_tag="",
+                                chunk_tag="",
                             )
-                        )
-                        # +1 to account for the white space
-                        cursor += len(split[0]) + 1
+                            ner_labels.append(pred)
+                        else:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=token_list[-1],
+                                    word=token_list[0],
+                                    start=cursor,
+                                    end=cursor + len(token_list[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=token_list[1],
+                                    chunk_tag=token_list[2],
+                                )
+                            )
+                            cursor += len(token_list[0]) + 1

                     original = " ".join([label.span.word for label in ner_labels])

                     data.append(
                         self.task.get_sample_class(
                             original=original,
                             expected_results=NEROutput(predictions=ner_labels),
                         )
                     )
+
+                else:
+                    sentences = re.split(r"\n\n|\n\s+\n", doc.strip())
+
+                    if sentences == [""]:
+                        continue
+
+                    for sent in sentences:
+                        # sentence string to token level split
+                        tokens = sent.strip().split("\n")
+
+                        # get annotations from token level split
+                        valid_tokens, token_list = self.__token_validation(tokens)
+
+                        if not valid_tokens:
+                            logging.warning(Warnings.W004(sent=sent))
+                            continue
+
+                        # get token and labels from the split
+                        ner_labels = []
+                        cursor = 0
+                        for split in token_list:
+                            ner_labels.append(
+                                NERPrediction.from_span(
+                                    entity=split[-1],
+                                    word=split[0],
+                                    start=cursor,
+                                    end=cursor + len(split[0]),
+                                    doc_id=d_id,
+                                    doc_name=(
+                                        docs_strings[d_id]
+                                        if len(docs_strings) > 0
+                                        else ""
+                                    ),
+                                    pos_tag=split[1],
+                                    chunk_tag=split[2],
+                                )
+                            )
+                            # +1 to account for the white space
+                            cursor += len(split[0]) + 1
+
+                        original = " ".join([label.span.word for label in ner_labels])
+
+                        data.append(
+                            self.task.get_sample_class(
+                                original=original,
+                                expected_results=NEROutput(predictions=ner_labels),
+                            )
+                        )
         self.dataset_size = len(data)
         return data

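The `doc_wise` flag above flows from a dict-style data config into `ConllDataset`, which then labels a whole document as one sample instead of splitting it into sentences. A minimal usage sketch, assuming a dict-style config and a local CoNLL file; the path is a placeholder, and `TaskManager("ner")` mirrors how langtest constructs tasks internally:

from langtest.datahandler.datasource import DataFactory
from langtest.tasks import TaskManager

task = TaskManager("ner")
factory = DataFactory(
    {"data_source": "data/sample.conll", "doc_wise": True},  # placeholder path
    task=task,
)
samples = factory.load()  # one NERSample per document rather than per sentence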
68 changes: 39 additions & 29 deletions langtest/datahandler/format.py
@@ -195,43 +195,53 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, st
         test_case = sample.test_case
         original = sample.original
         if test_case:
-            test_case_items = test_case.split()
-            norm_test_case_items = test_case.lower().split()
-            norm_original_items = original.lower().split()
+            test_case_items = test_case.split(" ")
+            norm_test_case_items = test_case.lower().split(" ")
+            norm_original_items = original.lower().split(" ")
             temp_len = 0
             for jdx, item in enumerate(norm_test_case_items):
-                try:
-                    if item in norm_original_items and jdx >= norm_original_items.index(
-                        item
-                    ):
-                        oitem_index = norm_original_items.index(item)
-                        j = sample.expected_results.predictions[oitem_index + temp_len]
-                        if temp_id != j.doc_id and jdx == 0:
-                            text += f"{j.doc_name}\n\n"
-                            temp_id = j.doc_id
-                        text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
-                        norm_original_items.pop(oitem_index)
-                        temp_len += 1
-                    else:
-                        o_item = sample.expected_results.predictions[jdx].span.word
-                        letters_count = len(set(item) - set(o_item))
-                        if (
-                            len(norm_test_case_items) == len(original.lower().split())
-                            or letters_count < 2
-                        ):
-                            tl = sample.expected_results.predictions[jdx]
-                            text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
-                        else:
-                            text += f"{test_case_items[jdx]} -X- -X- O\n"
-                except IndexError:
-                    text += f"{test_case_items[jdx]} -X- -X- O\n"
+                if test_case_items[jdx] == "\n":
+                    text += "\n"  # add a newline character after each sentence
+                else:
+                    try:
+                        if (
+                            item in norm_original_items
+                            and jdx >= norm_original_items.index(item)
+                        ):
+                            oitem_index = norm_original_items.index(item)
+                            j = sample.expected_results.predictions[
+                                oitem_index + temp_len
+                            ]
+                            if temp_id != j.doc_id and jdx == 0:
+                                text += f"{j.doc_name}\n\n"
+                                temp_id = j.doc_id
+                            text += f"{test_case_items[jdx]} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                            norm_original_items.pop(oitem_index)
+                            temp_len += 1
+                        else:
+                            o_item = sample.expected_results.predictions[jdx].span.word
+                            letters_count = len(set(item) - set(o_item))
+                            if (
+                                len(norm_test_case_items)
+                                == len(original.lower().split(" "))
+                                or letters_count < 2
+                            ):
+                                tl = sample.expected_results.predictions[jdx]
+                                text += f"{test_case_items[jdx]} {tl.pos_tag} {tl.chunk_tag} {tl.entity}\n"
+                            else:
+                                text += f"{test_case_items[jdx]} -X- -X- O\n"
+                    except IndexError:
+                        text += f"{test_case_items[jdx]} -X- -X- O\n"

         else:
             for j in sample.expected_results.predictions:
-                if temp_id != j.doc_id:
-                    text += f"{j.doc_name}\n\n"
-                    temp_id = j.doc_id
-                text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"
+                if j.span.word == "\n":
+                    text += "\n"
+                else:
+                    if temp_id != j.doc_id:
+                        text += f"{j.doc_name}\n\n"
+                        temp_id = j.doc_id
+                    text += f"{j.span.word} {j.pos_tag} {j.chunk_tag} {j.entity}\n"

         return text, temp_id

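The newline handling added on both branches is what preserves sentence boundaries in the exported CoNLL text: any prediction whose word is "\n" becomes a blank line instead of a token row. A stripped-down, standalone sketch of just that logic (not langtest's actual formatter):

# (word, pos_tag, chunk_tag, entity) tuples; "\n" marks a sentence boundary
predictions = [
    ("John", "NNP", "B-NP", "B-PER"),
    ("sleeps", "VBZ", "B-VP", "O"),
    ("\n", "", "", ""),
    ("He", "PRP", "B-NP", "O"),
]

text = ""
for word, pos, chunk, entity in predictions:
    if word == "\n":
        text += "\n"  # blank line between sentences, as in the diff
    else:
        text += f"{word} {pos} {chunk} {entity}\n"

print(text)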
1 change: 1 addition & 0 deletions langtest/errors.py
@@ -275,6 +275,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
     E095 = ("Failed to make API request: {e}")
     E096 = ("Failed to generate the templates in Augmentation: {msg}")
+    E097 = ("Failed to load openai. Please install it using `pip install openai`")


 class ColumnNameError(Exception):
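For context, error tables like the one above typically work by having the metaclass wrap each code's message template in a callable, which is why the patch can write `raise ImportError(Errors.E097())`. A simplified, hypothetical stand-in (not langtest's actual ErrorsWithCodes implementation):

class ErrorsWithCodes(type):
    """Simplified stand-in: turns `Errors.E097` into a callable that formats
    the message template and prefixes the error code."""

    def __getattribute__(cls, code):
        msg = super().__getattribute__(code)
        if code.startswith("E") and isinstance(msg, str):
            return lambda **kwargs: f"[{code}] {msg.format(**kwargs)}"
        return msg


class Errors(metaclass=ErrorsWithCodes):
    E095 = "Failed to make API request: {e}"
    E097 = "Failed to load openai. Please install it using `pip install openai`"


print(Errors.E097())                      # [E097] Failed to load openai. ...
print(Errors.E095(e="connection reset"))  # [E095] Failed to make API request: connection reset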