Merge pull request #1101 from JohnSnowLabs/patch/2.3.1
Patch/2.3.1
chakravarthik27 authored Sep 9, 2024
2 parents 5b1c284 + 97bee72 commit 78cb31f
Showing 4 changed files with 184 additions and 11 deletions.
20 changes: 17 additions & 3 deletions langtest/augmentation/augmenter.py
@@ -8,10 +8,15 @@
 from langtest.transform import TestFactory
 from langtest.tasks.task import TaskManager
 from langtest.utils.custom_types.sample import Sample
+from langtest.logger import logger
 
 
 class DataAugmenter:
-    def __init__(self, task: Union[str, TaskManager], config: Union[str, dict]) -> None:
+    def __init__(
+        self,
+        task: Union[str, TaskManager],
+        config: Union[str, dict],
+    ) -> None:
         """
         Initialize the DataAugmenter.
@@ -241,11 +246,20 @@ def prepare_hash_map(
 
         return hashmap
 
-    def save(self, file_path: str):
+    def save(self, file_path: str, for_gen_ai=False) -> None:
         """
         Save the augmented data.
         """
-        self.__datafactory.export(data=self.__augmented_data, output_path=file_path)
+        try:
+            # For the ner task, a .json file_path is only allowed when
+            # for_gen_ai is True; otherwise it is rejected.
+            if not for_gen_ai and self.__task.task_name == "ner":
+                if file_path.endswith(".json"):
+                    raise ValueError("File path shouldn't be a .json file")
+
+            self.__datafactory.export(data=self.__augmented_data, output_path=file_path)
+        except Exception as e:
+            logger.error(f"Error in saving the augmented data: {e}")
 
     def __or__(self, other: Iterable):
         results = self.augment(other)
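A minimal usage sketch of the new flag (the task name, config path, and samples below are illustrative assumptions, not part of this diff):

from langtest.augmentation.augmenter import DataAugmenter

samples = []  # stand-in for an iterable of NER samples to augment
augmenter = DataAugmenter(task="ner", config="augment_config.yml")  # assumed inputs
augmenter | samples  # __or__ runs augment() over the iterable

# For the ner task, a .json path raises ValueError (caught and logged)
# unless the export is explicitly marked for Gen AI Lab:
augmenter.save("augmented_data.conll")                  # CoNLL export
augmenter.save("augmented_data.json", for_gen_ai=True)  # Gen AI Lab JSON export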
52 changes: 44 additions & 8 deletions langtest/datahandler/datasource.py
@@ -26,6 +26,7 @@
 from ..errors import Warnings, Errors
 import glob
 from pkg_resources import resource_filename
+from langtest.logger import logger
 
 COLUMN_MAPPER = {
     "text-classification": {
@@ -551,14 +552,49 @@ def export_data(self, data: List[NERSample], output_path: str):
             output_path (str):
                 path to save the data to
         """
-        otext = ""
-        temp_id = None
-        for i in data:
-            text, temp_id = Formatter.process(i, output_format="conll", temp_id=temp_id)
-            otext += text + "\n"
-
-        with open(output_path, "wb") as fwriter:
-            fwriter.write(bytes(otext, encoding="utf-8"))
+        if output_path.endswith(".conll"):
+            otext = ""
+            temp_id = None
+            for i in data:
+                text, temp_id = Formatter.process(
+                    i, output_format="conll", temp_id=temp_id
+                )
+                otext += text + "\n"
+
+            with open(output_path, "wb") as fwriter:
+                fwriter.write(bytes(otext, encoding="utf-8"))
+
+        elif output_path.endswith(".json"):
+            import json
+            from .utils import process_document
+
+            logger.warn("Only for Gen AI Lab use")
+            logger.info("Converting NER sample to JSON format")
+
+            otext_list = []
+            temp_id = None
+            for i in data:
+                otext, temp_id = Formatter.process(
+                    i, output_format="json", temp_id=temp_id
+                )
+                processed_text = process_document(otext)
+                # add test info
+                tem_dict = processed_text["data"]
+                tem_dict["test_type"] = i.test_type or "null"
+                tem_dict["category"] = i.category or "null"
+
+                processed_text["data"] = tem_dict
+                otext_list.append(processed_text)
+
+                # otext += text + "\n"
+                # if temp_id2 != temp_id:
+                #     processed_text = process_document(otext)
+                #     otext_list.append(processed_text)
+                #     otext = ""
+                #     temp_id = temp_id2
+
+            with open(output_path, "w") as fwriter:
+                json.dump(otext_list, fwriter)
 
     def __token_validation(self, tokens: str) -> (bool, List[List[str]]):  # type: ignore
         """Validates the tokens in a sentence.
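For reference, a sketch of one element of the list that the new .json branch writes. The shape follows process_document in langtest/datahandler/utils.py (added below), with test_type and category merged into "data" by export_data; all values here are illustrative:

record = {
    "created_ago": "2024-09-09T00:00:00.000000Z",  # UTC timestamp
    "result": [
        {
            "value": {
                "start": 0,
                "end": 4,
                "text": "John",
                "labels": ["PER"],
                "confidence": 1,
            },
            "from_name": "label",
            "to_name": "text",
            "type": "labels",
        }
    ],
    "honeypot": True,
    "lead_time": 10,
    "confidence_range": [0, 1],
    "submitted_at": "2024-09-09T00:00:00.000000Z",
    "updated_at": "2024-09-09T00:00:00.000000Z",
    "predictions": [],
    "created_at": "2024-09-09T00:00:00.000000Z",
    "data": {
        "text": "John lives in Paris.",  # illustrative sentence
        "test_type": "uppercase",        # merged in by export_data; illustrative
        "category": "robustness",        # merged in by export_data; illustrative
    },
}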
7 changes: 7 additions & 0 deletions langtest/datahandler/format.py
@@ -235,6 +235,13 @@ def to_conll(sample: NERSample, temp_id: int = None) -> Union[str, Tuple[str, str]]:
 
         return text, temp_id
 
+    @staticmethod
+    def to_json(sample: NERSample, temp_id: int = None) -> Tuple[str, str]:
+        """Returns the CoNLL text and temp_id for a NERSample (delegates to to_conll)."""
+
+        text, temp_id = NEROutputFormatter.to_conll(sample, temp_id)
+        return text, temp_id
+
 
 class QAFormatter(BaseFormatter):
     def to_jsonl(sample: QASample, *args, **kwargs):
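A quick sketch of the call (sample is an assumed NERSample instance; the method returns the same pair as to_conll, which export_data then feeds to process_document):

# Assumed: `sample` is a NERSample; temp_id threads document IDs across
# consecutive samples, exactly as in to_conll.
text, temp_id = NEROutputFormatter.to_json(sample, temp_id=None)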
116 changes: 116 additions & 0 deletions langtest/datahandler/utils.py
@@ -0,0 +1,116 @@
from datetime import datetime


def get_results(tokens, labels, text):
    current_entity = None
    current_span = []
    results = []
    char_pos = 0  # Tracks the character position in the text

    for i, (token, label) in enumerate(zip(tokens, labels)):
        token_start = char_pos
        token_end = token_start + len(token)
        if label.startswith("B-"):
            if current_entity:
                results.append(
                    {
                        "value": {
                            "start": current_span[0],
                            "end": current_span[-1],
                            "text": text[current_span[0] : current_span[-1]],
                            "labels": [current_entity],
                            "confidence": 1,
                        },
                        "from_name": "label",
                        "to_name": "text",
                        "type": "labels",
                    }
                )
            current_entity = label[2:]
            current_span = [token_start, token_end]
        elif label.startswith("I-") and current_entity:
            current_span[-1] = token_end
        elif label == "O" and current_entity:
            results.append(
                {
                    "value": {
                        "start": current_span[0],
                        "end": current_span[-1],
                        "text": text[current_span[0] : current_span[-1]],
                        "labels": [current_entity],
                        "confidence": 1,
                    },
                    "from_name": "label",
                    "to_name": "text",
                    "type": "labels",
                }
            )
            current_entity = None
            current_span = []

        # Move to the next character position (account for the space between tokens)
        char_pos = (
            token_end + 1
            if i + 1 < len(tokens) and tokens[i + 1] not in [".", ",", "!", "?"]
            else token_end
        )

    if current_entity:
        results.append(
            {
                "value": {
                    "start": current_span[0],
                    "end": current_span[-1],
                    "text": text[current_span[0] : current_span[-1]],
                    "labels": [current_entity],
                    "confidence": 1,
                },
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
            }
        )
    return results

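A hand-traced example of what get_results produces for a small sentence (the PER/LOC labels and text are illustrative):

tokens = ["John", "lives", "in", "Paris", "."]
labels = ["B-PER", "O", "O", "B-LOC", "O"]
text = "John lives in Paris."  # one space between tokens, none before the period

spans = get_results(tokens, labels, text)
# spans[0]["value"] -> {"start": 0, "end": 4, "text": "John",
#                       "labels": ["PER"], "confidence": 1}
# spans[1]["value"] -> {"start": 14, "end": 19, "text": "Paris",
#                       "labels": ["LOC"], "confidence": 1}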

def process_document(doc):
    tokens = []
    labels = []

    # drop the -DOCSTART- tag
    doc = doc.replace("-DOCSTART-", "")

    for line in doc.strip().split("\n"):
        if line.strip():
            parts = line.strip().split()
            if len(parts) == 4:
                token, _, _, label = parts
                tokens.append(token)
                labels.append(label)

    text = ""
    for token in tokens:
        if token in {".", ",", "!", "?"}:
            text = text.rstrip() + token + " "
        else:
            text += token + " "

    text = text.rstrip()

    results = get_results(tokens, labels, text)
    now = datetime.utcnow()
    current_date = now.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
    json_output = {
        "created_ago": current_date,
        "result": results,
        "honeypot": True,
        "lead_time": 10,
        "confidence_range": [0, 1],
        "submitted_at": current_date,
        "updated_at": current_date,
        "predictions": [],
        "created_at": current_date,
        "data": {"text": text},
    }

    return json_output

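And a sketch of the end-to-end conversion (the four-column CoNLL snippet is illustrative; process_document keeps only lines with exactly four columns, reading the first as the token and the last as the label):

conll_doc = """-DOCSTART- -X- -X- O

John NNP B-NP B-PER
lives VBZ B-VP O
in IN B-PP O
Paris NNP B-NP B-LOC
. . O O
"""

record = process_document(conll_doc)
# record["data"]["text"] -> "John lives in Paris."
# record["result"]       -> the two entity spans from the get_results example above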