Swag #6

Open · wants to merge 5 commits into master
Changes from all commits
110 changes: 101 additions & 9 deletions lilbert/lib/data_processors.py
@@ -46,7 +46,7 @@ def get_train_examples(self, data_dir):
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()

def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@@ -76,7 +76,7 @@ def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

def get_labels(self):
"""See base class."""
return ["0", "1"]
@@ -109,7 +109,7 @@ def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")

def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@@ -156,11 +156,11 @@ def _create_examples(self, lines, set_type):
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples


class SST2Processor(DataProcessor):
"""Processor for the SST2 data set (GLUE version)."""

def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
@@ -186,7 +186,7 @@ def _create_examples(self, lines, set_type):
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples


def get_quora_df(filename):
with open(filename, "r", encoding='utf-8') as f:
rows = list(csv.reader(f, delimiter='\t', quotechar=None))
@@ -195,11 +195,11 @@ def get_quora_df(filename):
df = df[pd.notnull(df['is_duplicate'])]
df.columns = ['text_a', 'text_b', 'label']
return df


class QQPProcessor(DataProcessor):
"""Processor for the SST2 data set (GLUE version)."""

def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
@@ -225,3 +225,95 @@ def _create_examples(self, df, set_type):
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples


class SwagExample(object):
"""A single training/test example for the SWAG dataset."""

def __init__(self,
swag_id,
context_sentence,
start_ending,
ending_0,
ending_1,
ending_2,
ending_3,
label=None):
self.swag_id = swag_id
self.context_sentence = context_sentence
self.start_ending = start_ending
self.endings = [
ending_0,
ending_1,
ending_2,
ending_3,
]
self.label = label

def __str__(self):
return self.__repr__()

def __repr__(self):
l = [
"swag_id: {}".format(self.swag_id),
"context_sentence: {}".format(self.context_sentence),
"start_ending: {}".format(self.start_ending),
"ending_0: {}".format(self.endings[0]),
"ending_1: {}".format(self.endings[1]),
"ending_2: {}".format(self.endings[2]),
"ending_3: {}".format(self.endings[3]),
]

if self.label is not None:
l.append("label: {}".format(self.label))

return ", ".join(l)


class SWAGProcessor(DataProcessor):
"""Processor for the SWAG data set."""

def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_csv(os.path.join(data_dir, "train.csv"), True)
return self._create_examples(lines, True)

def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_csv(os.path.join(data_dir, "val.csv"), True)
return self._create_examples(lines, True)

def _read_csv(self, input_file, is_training):
with open(input_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)

if is_training and lines[0][-1] != 'label':
raise ValueError(
"For training, the input file must contain a label column."
)
return lines[1:]

def _create_examples(self, lines, is_training):
"""Creates examples for the training and dev sets."""
examples = []
for line in lines:
examples.append(
SwagExample(
swag_id=line[2],
context_sentence=line[4],
start_ending=line[5], # in the swag dataset, the
# common beginning of each
# choice is stored in "sent2".
ending_0=line[7],
ending_1=line[8],
ending_2=line[9],
ending_3=line[10],
label=int(line[11]) if is_training else None
)
)
return examples
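
Reviewer note: a minimal usage sketch for the new processor (the data path is
hypothetical; assumes the standard SWAG train.csv/val.csv files in data_dir):

from lib.data_processors import SWAGProcessor

processor = SWAGProcessor()
train_examples = processor.get_train_examples('data/swag')  # hypothetical path
print(train_examples[0])  # __repr__ shows swag_id, context, the four endings, label
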
93 changes: 82 additions & 11 deletions lilbert/lib/feature_processors.py
@@ -1,14 +1,23 @@
import numpy as np
from lib.data_processors import InputFeatures
from tqdm import tqdm


def convert_examples_to_features(examples, tokenizer, params):
if params['task_name'] != 'swag':
return convert_examples_to_features_clf(examples, params['label_list'],
params['max_seq_length'], tokenizer)
else:
return convert_examples_to_features_swag(examples, params['max_seq_length'], tokenizer)
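# Reviewer sketch: callers now pass a params dict instead of the positional
# label_list/max_seq_length. A hypothetical call for the SWAG branch:
#
#     params = {'task_name': 'swag', 'max_seq_length': 80}
#     features = convert_examples_to_features(examples, tokenizer, params)
#
# Non-SWAG tasks must additionally supply params['label_list'].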


def convert_examples_to_features_clf(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""

    label_map = {label: i for i, label in enumerate(label_list)}

features = []
    for (ex_index, example) in tqdm(enumerate(examples), total=len(examples), desc='converting examples'):
tokens_a = tokenizer.tokenize(example.text_a)

tokens_b = None
@@ -66,10 +75,76 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer

label_id = label_map[example.label]
features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id))
return features


def convert_examples_to_features_swag(examples, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""

# Swag is a multiple choice task. To perform this task using Bert,
# we will use the formatting proposed in "Improving Language
# Understanding by Generative Pre-Training" and suggested by
# @jacobdevlin-google in this issue
# https://github.com/google-research/bert/issues/38.
#
# Each choice will correspond to a sample on which we run the
# inference. For a given Swag example, we will create the 4
# following inputs:
# - [CLS] context [SEP] choice_1 [SEP]
# - [CLS] context [SEP] choice_2 [SEP]
# - [CLS] context [SEP] choice_3 [SEP]
# - [CLS] context [SEP] choice_4 [SEP]
# The model will output a single value for each input. To get the
# final decision of the model, we will run a softmax over these 4
# outputs.
features = []
for example_index, example in tqdm(enumerate(examples), total=len(examples), desc='converting examples'):
# for example_index, example in enumerate(examples):
context_tokens = tokenizer.tokenize(example.context_sentence)
start_ending_tokens = tokenizer.tokenize(example.start_ending)

choices_input_ids, choices_input_mask, choices_segment_ids = [], [], []
for ending_index, ending in enumerate(example.endings):
# We create a copy of the context tokens in order to be
# able to shrink it according to ending_tokens
context_tokens_choice = context_tokens[:]
ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
# Modifies `context_tokens_choice` and `ending_tokens` in
# place so that the total length is less than the
# specified length. Account for [CLS], [SEP], [SEP] with
# "- 3"
_truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)

tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)

input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)

# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding

assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length

choices_input_ids.append(input_ids)
choices_input_mask.append(input_mask)
choices_segment_ids.append(segment_ids)

features.append(
InputFeatures(input_ids=choices_input_ids,
input_mask=choices_input_mask,
segment_ids=choices_segment_ids,
label_id=example.label)
)
return features
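# Reviewer sketch: unlike the clf path, input_ids/input_mask/segment_ids in
# each InputFeatures are lists of four sequences here, so the downstream
# tensors become rank 3 (assumption, mirroring the tensor-building code in
# train_eval.py):
#
#     all_input_ids = torch.tensor([f.input_ids for f in features],
#                                  dtype=torch.long)
#     # shape: (num_examples, 4, max_seq_length)
#
# A multiple-choice head (e.g. BertForMultipleChoice from
# pytorch_pretrained_bert) flattens the choice dimension, scores each
# sequence, and softmaxes over the four logits.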


@@ -88,7 +163,3 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
tokens_a.pop()
else:
tokens_b.pop()
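# Worked example (sketch): with len(tokens_a) == 10, len(tokens_b) == 3 and
# max_length == 8, the loop pops five tokens from tokens_a only, since the
# longer sequence is trimmed first, ending at lengths 5 and 3.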

1 change: 1 addition & 0 deletions lilbert/lib/tasks.py
@@ -6,6 +6,7 @@
'mrpc': data_processors.MrpcProcessor,
'sst2': data_processors.SST2Processor,
'qqp': data_processors.QQPProcessor,
'swag': data_processors.SWAGProcessor,
}

num_labels = {
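
For context, a sketch of how the registry is consumed (the matching num_labels
entry for 'swag', if the PR adds one, falls outside the visible hunk):

from lib import tasks

processor = tasks.processors['swag']()   # -> data_processors.SWAGProcessor
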
43 changes: 20 additions & 23 deletions lilbert/lib/train_eval.py
@@ -9,7 +9,6 @@
from pytorch_pretrained_bert.modeling import BertConfig

from lib import feature_processors, metrics


def train(model, tokenizer, params,
@@ -20,7 +19,7 @@ def train(model, tokenizer, params,
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

train_steps_per_epoch = int(len(train_examples) / params['train_batch_size'])
num_train_optimization_steps = train_steps_per_epoch * params['num_train_epochs']

@@ -38,32 +37,31 @@ def train(model, tokenizer, params,
lr=params['learning_rate'],
warmup=params['warmup_proportion'],
t_total=num_train_optimization_steps)

global_step = 0
nb_tr_steps = 0
tr_loss = 0

train_features = feature_processors.convert_examples_to_features(
train_examples,
        tokenizer,
        params)
print("***** Running training *****")
print("Num examples:", len(train_examples))
print("Num examples:", len(train_examples))
print("Batch size: ", params['train_batch_size'])
print("Num steps: ", num_train_optimization_steps)
    all_input_ids = torch.tensor(
        [f.input_ids for f in train_features],
        dtype=torch.long)
    all_input_mask = torch.tensor(
        [f.input_mask for f in train_features],
        dtype=torch.long)
    all_segment_ids = torch.tensor(
        [f.segment_ids for f in train_features],
        dtype=torch.long)
    all_label_ids = torch.tensor(
        [f.label_id for f in train_features],
        dtype=torch.long)
train_data = TensorDataset(all_input_ids,
all_input_mask,
all_segment_ids,
Expand Down Expand Up @@ -117,29 +115,28 @@ def train(model, tokenizer, params,
'train_loss': tr_loss / nb_tr_steps,
'train_global_step': global_step,
}

return model, train_result


def predict(model, tokenizer, params, valid_examples):
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

eval_features = feature_processors.convert_examples_to_features(
        valid_examples,
        tokenizer,
        params)
    all_input_ids = torch.tensor(
        [f.input_ids for f in eval_features],
        dtype=torch.long)
    all_input_mask = torch.tensor(
        [f.input_mask for f in eval_features],
        dtype=torch.long)
    all_segment_ids = torch.tensor(
        [f.segment_ids for f in eval_features],
        dtype=torch.long)
eval_data = TensorDataset(all_input_ids,
all_input_mask,
all_segment_ids)
Expand Down Expand Up @@ -167,9 +164,9 @@ def evaluate(model, tokenizer, params, valid_examples):
print("***** Running evaluation *****")
print("Num examples: ", len(valid_examples))
print("Batch size: ", params['eval_batch_size'])

prob_preds = predict(model, tokenizer, params, valid_examples)
    true_labels = np.array([int(example.label)
for i, example in enumerate(valid_examples)])
result = {
'eval_loss': metrics.log_loss(true_labels, prob_preds),
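
To tie the pieces together, a minimal end-to-end sketch under assumed names:
every params key below appears in the functions above, but the values and
paths are hypothetical, and whether train() needs further changes for the
multiple-choice loss is not shown in this diff.

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from lib import tasks, train_eval

params = {
    'seed': 42,
    'task_name': 'swag',
    'max_seq_length': 80,
    'train_batch_size': 32,
    'eval_batch_size': 32,
    'learning_rate': 5e-5,
    'warmup_proportion': 0.1,
    'num_train_epochs': 3,
    'label_list': None,   # unused on the SWAG path
}

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForMultipleChoice.from_pretrained('bert-base-uncased', num_choices=4)

processor = tasks.processors[params['task_name']]()
train_examples = processor.get_train_examples('data/swag')   # hypothetical path
valid_examples = processor.get_dev_examples('data/swag')

model, train_result = train_eval.train(model, tokenizer, params, train_examples)
result = train_eval.evaluate(model, tokenizer, params, valid_examples)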