Skip to content

Commit

Permalink
Merge pull request #60 from Maitreyapatel/sa_augmentations
Browse files Browse the repository at this point in the history
Augmentations for Sentiment Analysis
  • Loading branch information
Maitreyapatel authored Apr 4, 2023
2 parents 866f4a3 + 4b9b25c commit 3885be8
Show file tree
Hide file tree
Showing 8 changed files with 899 additions and 713 deletions.
1,452 changes: 740 additions & 712 deletions data/parrot_sentiment140.csv

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions reliability_checklist/augmentation/augments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from reliability_checklist.augmentation.mnli.num_word import num_word_augmentation
from reliability_checklist.augmentation.mnli.rand_sent import rand_sentence_augmentation
from reliability_checklist.augmentation.mnli.swap_ant import swap_ant_augmentation
from reliability_checklist.augmentation.sentiment_analysis.back_translate import (
back_translate_augmentation,
)
from reliability_checklist.augmentation.sentiment_analysis.double_denial import (
double_denial_augmentation,
)


class Augmentation:
Expand Down Expand Up @@ -84,6 +90,24 @@ def augment(self):
self.dataset = self.augmenter.infer(self.dataset)


class back_translate_aug(Augmentation):
    """Augmentation wrapper that round-trips text through machine translation.

    Delegates the actual work to ``back_translate_augmentation``; this class
    only adapts it to the common ``Augmentation`` interface.
    """

    def __init__(self, __name__="BACK_TRANS", dataset=None, cols=None):
        """Initialize the wrapper and build the underlying augmenter.

        Args:
            __name__: display name used by the augmentation registry.
            dataset: dataset to augment (forwarded to the base class).
            cols: dataset columns handled by the augmenter.
        """
        super().__init__(__name__, dataset)
        self.augmenter = back_translate_augmentation(cols)

    def augment(self):
        """Run back-translation and replace the stored dataset in place."""
        self.dataset = self.augmenter.infer(self.dataset)


class double_denial_aug(Augmentation):
    """Augmentation wrapper that rewrites polarity words as double negations.

    Delegates the actual work to ``double_denial_augmentation``; this class
    only adapts it to the common ``Augmentation`` interface.
    """

    def __init__(self, __name__="DOUBLE_DENIAL", dataset=None, cols=None):
        """Initialize the wrapper and build the underlying augmenter.

        Args:
            __name__: display name used by the augmentation registry.
            dataset: dataset to augment (forwarded to the base class).
            cols: dataset columns handled by the augmenter.
        """
        super().__init__(__name__, dataset)
        self.augmenter = double_denial_augmentation(cols)

    def augment(self):
        """Run the double-denial rewrite and replace the stored dataset."""
        self.dataset = self.augmenter.infer(self.dataset)


class parrot_paraphraser(Augmentation):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pandas as pd
from datasets import ClassLabel, Dataset
from tqdm import tqdm
from transformers import MarianMTModel, MarianTokenizer


class back_translate_augmentation:
    """Augments a sentiment dataset via round-trip (back-)translation.

    Each example's "text" is translated English -> French -> English with
    MarianMT models; examples whose round trip differs from the original
    text are emitted as new augmented rows.
    """

    def __init__(self, cols=None):
        """Store configuration.

        Args:
            cols: dataset columns that are augmented (and therefore not
                copied verbatim into the output). ``None`` means no columns.
        """
        # Guard: the previous implementation crashed in infer() when
        # cols was left as None (list + None -> TypeError).
        self.cols = cols if cols is not None else []
        self.model_translate_name = "Helsinki-NLP/opus-mt-en-roa"
        self.model_back_translate_name = "Helsinki-NLP/opus-mt-roa-en"

    def download(self, model_name):
        """Load a MarianMT model and its tokenizer from the HF hub.

        Returns:
            (model, tokenizer) tuple.
        """
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return model, tokenizer

    def infer(self, dataset):
        """Return a new Dataset of back-translated examples.

        The output keeps every column of *dataset* plus a "mapping" column
        holding the index of the source row. Rows whose back-translation
        equals the original text are dropped.
        """
        model_translate, model_translate_tokenizer = self.download(self.model_translate_name)
        model_back_translate, model_back_translate_tokenizer = self.download(
            self.model_back_translate_name
        )

        datacols = list(dataset.features.keys()) + ["mapping"]
        new_dataset = {k: [] for k in datacols}
        # Columns handled explicitly below; everything else is copied as-is.
        skip_cols = ["label", "mapping"] + self.cols

        for i in tqdm(range(len(dataset))):
            # ">>fra<<" selects French as the target language of the
            # multilingual en->romance MarianMT model.
            src_text = [">>fra<< " + dataset[i]["text"]]
            src_translated = model_translate.generate(
                **model_translate_tokenizer(src_text, return_tensors="pt", padding=True)
            )

            tgt_text = [
                model_translate_tokenizer.decode(t, skip_special_tokens=True)
                for t in src_translated
            ][0]
            tgt_text = [">>eng<< " + tgt_text]
            tgt_translated = model_back_translate.generate(
                **model_back_translate_tokenizer(tgt_text, return_tensors="pt", padding=True)
            )

            back_translated_text = [
                model_back_translate_tokenizer.decode(t, skip_special_tokens=True)
                for t in tgt_translated
            ][0]
            if back_translated_text != dataset[i]["text"]:
                # BUG FIX: was `new_dataset["text"] = back_translated_text`,
                # which replaced the accumulator list with a single string and
                # desynchronized column lengths before pd.DataFrame().
                new_dataset["text"].append(back_translated_text)
                new_dataset["label"].append(dataset["label"][i])
                new_dataset["mapping"].append(i)
                for k in datacols:
                    if k not in skip_cols:
                        new_dataset[k].append(dataset[k][i])

        new_dataset = pd.DataFrame(new_dataset)
        return Dataset.from_pandas(new_dataset).cast_column("label", dataset.features["label"])
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pandas as pd
from datasets import ClassLabel, Dataset
from tqdm import tqdm
from transformers import MarianMTModel, MarianTokenizer


class double_denial_augmentation:
    """Augments sentiment text by rewriting polarity words as double negations.

    Tokens that appear in ``polarity_dict`` (e.g. "good") are replaced with a
    negated near-antonym ("not poor"); only rows containing at least one
    replacement are emitted as new examples.
    """

    def __init__(self, cols=None):
        """Store configuration.

        Args:
            cols: dataset columns that are augmented (and therefore not
                copied verbatim into the output). ``None`` means no columns.
        """
        # Guard: the previous implementation crashed in infer() when
        # cols was left as None (list + None -> TypeError).
        self.cols = cols if cols is not None else []
        # token -> double-denial rewrite with (roughly) the same sentiment.
        self.polarity_dict = {
            "poor": "not good",
            "bad": "not great",
            "lame": "not interesting",
            "awful": "not awesome",
            "great": "not bad",
            "good": "not poor",
            "applause": "not discourage",
            "recommend": "don't prevent",
            "best": "not worst",
            "encourage": "don't discourage",
            "entertain": "don't disapprove",
            "wonderfully": "not poorly",
            "love": "don't hate",
            "interesting": "not uninteresting",
            "interested": "not ignorant",
            "glad": "not reluctant",
            "positive": "not negative",
            "perfect": "not imperfect",
            "entertaining": "not uninteresting",
            "moved": "not moved",
            "like": "don't refuse",
            "worth": "not undeserving",
            "better": "not worse",
            "funny": "not uninteresting",
            "awesome": "not ugly",
            "impressed": "not impressed",
        }

    def infer(self, dataset):
        """Return a Dataset of double-denial rewrites of *dataset*.

        The output keeps every column of *dataset* plus a "mapping" column
        holding the index of the source row; rows with no replaceable token
        are dropped.
        """
        datacols = list(dataset.features.keys()) + ["mapping"]
        new_dataset = {k: [] for k in datacols}
        # Columns handled explicitly below; everything else is copied as-is.
        skip_cols = ["label", "mapping"] + self.cols

        for i in tqdm(range(len(dataset))):
            replaced = False
            pieces = []
            # NOTE(review): matching is on exact whitespace-split tokens, so
            # words with attached punctuation (e.g. "good,") are not replaced.
            for token in dataset[i]["text"].split():
                if token in self.polarity_dict:
                    pieces.append(self.polarity_dict[token])
                    replaced = True
                else:
                    pieces.append(token)
            # The original appended a space after every token, leaving one
            # trailing space; preserved for output compatibility.
            augmented_string = "".join(p + " " for p in pieces)

            if replaced:
                new_dataset["text"].append(augmented_string)
                new_dataset["label"].append(dataset["label"][i])
                new_dataset["mapping"].append(i)

                for k in datacols:
                    if k not in skip_cols:
                        new_dataset[k].append(dataset[k][i])

        new_dataset = pd.DataFrame(new_dataset)
        return Dataset.from_pandas(new_dataset).cast_column("label", dataset.features["label"])
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Hydra config registering the back-translation augmentation.
back_translate:
  _target_: reliability_checklist.augmentation.augments.back_translate_aug
  _partial_: true  # instantiated later with the dataset argument
  __name__: "BACK_TRANS"
  # Columns the augmenter rewrites; resolved from the active datamodule.
  cols: ${datamodule.dataset_specific_args.cols}
5 changes: 5 additions & 0 deletions reliability_checklist/configs/augmentation/double_denial.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Hydra config registering the double-denial augmentation.
double_denial:
  _target_: reliability_checklist.augmentation.augments.double_denial_aug
  _partial_: true  # instantiated later with the dataset argument
  __name__: "DOUBLE_DENIAL"
  # Columns the augmenter rewrites; resolved from the active datamodule.
  cols: ${datamodule.dataset_specific_args.cols}
2 changes: 2 additions & 0 deletions reliability_checklist/configs/augmentation/sentiment.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Augmentations enabled for the sentiment-analysis task.
defaults:
  - default.yaml
  - parrot.yaml
  - back_translate.yaml
  - double_denial.yaml
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0

# --------- loggers --------- #
# wandb
wandb
# neptune-client
# mlflow
# comet-ml
Expand Down

0 comments on commit 3885be8

Please sign in to comment.