diff --git a/lilbert/lib/data_processors.py b/lilbert/lib/data_processors.py
index 28097c9..fec088f 100644
--- a/lilbert/lib/data_processors.py
+++ b/lilbert/lib/data_processors.py
@@ -46,7 +46,7 @@ def get_train_examples(self, data_dir):
     def get_dev_examples(self, data_dir):
         """Gets a collection of `InputExample`s for the dev set."""
         raise NotImplementedError()
-
+
     def get_labels(self):
         """Gets the list of labels for this data set."""
         raise NotImplementedError()
@@ -76,7 +76,7 @@ def get_dev_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
-
+
     def get_labels(self):
         """See base class."""
         return ["0", "1"]
@@ -109,7 +109,7 @@ def get_dev_examples(self, data_dir):
         return self._create_examples(
             self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
             "dev_matched")
-
+
     def get_labels(self):
         """See base class."""
         return ["contradiction", "entailment", "neutral"]
@@ -156,11 +156,11 @@ def _create_examples(self, lines, set_type):
             examples.append(
                 InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
         return examples
-
+
 
 class SST2Processor(DataProcessor):
     """Processor for the SST2 data set (GLUE version)."""
-
+
     def get_train_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
@@ -186,7 +186,7 @@ def _create_examples(self, lines, set_type):
                 InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
         return examples
-
+
 
 def get_quora_df(filename):
     with open(filename, "r", encoding='utf-8') as f:
         rows = list(csv.reader(f, delimiter='\t', quotechar=None))
@@ -195,11 +195,11 @@
     df = df[pd.notnull(df['is_duplicate'])]
     df.columns = ['text_a', 'text_b', 'label']
     return df
-
-
+
+
 class QQPProcessor(DataProcessor):
     """Processor for the SST2 data set (GLUE version)."""
-
+
     def get_train_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
@@ -225,3 +225,95 @@ def _create_examples(self, df, set_type):
             examples.append(
                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
+
+
+class SwagExample(object):
+    """A single training/test example for the SWAG dataset."""
+
+    def __init__(self,
+                 swag_id,
+                 context_sentence,
+                 start_ending,
+                 ending_0,
+                 ending_1,
+                 ending_2,
+                 ending_3,
+                 label=None):
+        self.swag_id = swag_id
+        self.context_sentence = context_sentence
+        self.start_ending = start_ending
+        self.endings = [
+            ending_0,
+            ending_1,
+            ending_2,
+            ending_3,
+        ]
+        self.label = label
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        l = [
+            "swag_id: {}".format(self.swag_id),
+            "context_sentence: {}".format(self.context_sentence),
+            "start_ending: {}".format(self.start_ending),
+            "ending_0: {}".format(self.endings[0]),
+            "ending_1: {}".format(self.endings[1]),
+            "ending_2: {}".format(self.endings[2]),
+            "ending_3: {}".format(self.endings[3]),
+        ]
+
+        if self.label is not None:
+            l.append("label: {}".format(self.label))
+
+        return ", ".join(l)
+
+
+class SWAGProcessor(DataProcessor):
+    """Processor for the SWAG data set."""
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        lines = self._read_csv(os.path.join(data_dir, "train.csv"), True)
+        return self._create_examples(lines, True)
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        lines = self._read_csv(os.path.join(data_dir, "val.csv"), True)
+        return self._create_examples(lines, True)
+
+    def _read_csv(self, input_file, is_training):
+        with open(input_file, 'r', encoding='utf-8') as f:
+            reader = csv.reader(f)
+            lines = []
+            for line in reader:
+                if sys.version_info[0] == 2:
+                    line = list(unicode(cell, 'utf-8') for cell in line)
+                lines.append(line)
+
+        if is_training and lines[0][-1] != 'label':
+            raise ValueError(
+                "For training, the input file must contain a label column."
+            )
+        return lines[1:]
+
+    def _create_examples(self, lines, is_training):
+        """Creates examples for the training and dev sets."""
+        examples = []
+        for line in lines:
+            examples.append(
+                SwagExample(
+                    swag_id=line[2],
+                    context_sentence=line[4],
+                    start_ending=line[5],  # in the swag dataset, the
+                                           # common beginning of each
+                                           # choice is stored in "sent2".
+                    ending_0=line[7],
+                    ending_1=line[8],
+                    ending_2=line[9],
+                    ending_3=line[10],
+                    label=int(line[11]) if is_training else None
+                )
+            )
+        return examples
diff --git a/lilbert/lib/feature_processors.py b/lilbert/lib/feature_processors.py
index c00073c..9f357bb 100644
--- a/lilbert/lib/feature_processors.py
+++ b/lilbert/lib/feature_processors.py
@@ -1,14 +1,23 @@
 import numpy as np
 from lib.data_processors import InputFeatures
+from tqdm import tqdm
 
 
-def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
+def convert_examples_to_features(examples, tokenizer, params):
+    if params['task_name'] != 'swag':
+        return convert_examples_to_features_clf(examples, params['label_list'],
+                                                params['max_seq_length'], tokenizer)
+    else:
+        return convert_examples_to_features_swag(examples, params['max_seq_length'], tokenizer)
+
+
+def convert_examples_to_features_clf(examples, label_list, max_seq_length, tokenizer):
     """Loads a data file into a list of `InputBatch`s."""
 
-    label_map = {label : i for i, label in enumerate(label_list)}
+    label_map = {label: i for i, label in enumerate(label_list)}
 
     features = []
-    for (ex_index, example) in enumerate(examples):
+    for (ex_index, example) in tqdm(enumerate(examples), total=len(examples), desc='converting examples'):
         tokens_a = tokenizer.tokenize(example.text_a)
 
         tokens_b = None
@@ -66,10 +75,76 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
         label_id = label_map[example.label]
 
         features.append(
-                InputFeatures(input_ids=input_ids,
-                              input_mask=input_mask,
-                              segment_ids=segment_ids,
-                              label_id=label_id))
+            InputFeatures(input_ids=input_ids,
+                          input_mask=input_mask,
+                          segment_ids=segment_ids,
+                          label_id=label_id))
     return features
+
+
+def convert_examples_to_features_swag(examples, max_seq_length, tokenizer):
+    """Loads a data file into a list of `InputBatch`s."""
+
+    # Swag is a multiple choice task. To perform this task using Bert,
+    # we will use the formatting proposed in "Improving Language
+    # Understanding by Generative Pre-Training" and suggested by
+    # @jacobdevlin-google in this issue
+    # https://github.com/google-research/bert/issues/38.
+    #
+    # Each choice will correspond to a sample on which we run the
+    # inference. For a given Swag example, we will create the 4
+    # following inputs:
+    # - [CLS] context [SEP] choice_1 [SEP]
+    # - [CLS] context [SEP] choice_2 [SEP]
+    # - [CLS] context [SEP] choice_3 [SEP]
+    # - [CLS] context [SEP] choice_4 [SEP]
+    # The model will output a single value for each input. To get the
+    # final decision of the model, we will run a softmax over these 4
+    # outputs.
+    features = []
+    for example_index, example in tqdm(enumerate(examples), total=len(examples), desc='converting examples'):
+        # for example_index, example in enumerate(examples):
+        context_tokens = tokenizer.tokenize(example.context_sentence)
+        start_ending_tokens = tokenizer.tokenize(example.start_ending)
+
+        choices_input_ids, choices_input_mask, choices_segment_ids = [], [], []
+        for ending_index, ending in enumerate(example.endings):
+            # We create a copy of the context tokens in order to be
+            # able to shrink it according to ending_tokens
+            context_tokens_choice = context_tokens[:]
+            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
+            # Modifies `context_tokens_choice` and `ending_tokens` in
+            # place so that the total length is less than the
+            # specified length. Account for [CLS], [SEP], [SEP] with
+            # "- 3"
+            _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
+
+            tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
+            segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
+
+            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            input_mask = [1] * len(input_ids)
+
+            # Zero-pad up to the sequence length.
+            padding = [0] * (max_seq_length - len(input_ids))
+            input_ids += padding
+            input_mask += padding
+            segment_ids += padding
+
+            assert len(input_ids) == max_seq_length
+            assert len(input_mask) == max_seq_length
+            assert len(segment_ids) == max_seq_length
+
+            choices_input_ids.append(input_ids)
+            choices_input_mask.append(input_mask)
+            choices_segment_ids.append(segment_ids)
+
+        features.append(
+            InputFeatures(input_ids=choices_input_ids,
+                          input_mask=choices_input_mask,
+                          segment_ids=choices_segment_ids,
+                          label_id=example.label)
+        )
+    return features
 
 
@@ -88,7 +163,3 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
             tokens_a.pop()
         else:
             tokens_b.pop()
-
-def accuracy(out, labels):
-    outputs = np.argmax(out, axis=1)
-    return np.sum(outputs == labels)
diff --git a/lilbert/lib/tasks.py b/lilbert/lib/tasks.py
index 3d7efa3..f33b7f3 100644
--- a/lilbert/lib/tasks.py
+++ b/lilbert/lib/tasks.py
@@ -6,6 +6,7 @@
     'mrpc': data_processors.MrpcProcessor,
     'sst2': data_processors.SST2Processor,
     'qqp': data_processors.QQPProcessor,
+    'swag': data_processors.SWAGProcessor,
 }
 
 num_labels = {
diff --git a/lilbert/lib/train_eval.py b/lilbert/lib/train_eval.py
index 80fc772..e4765fc 100644
--- a/lilbert/lib/train_eval.py
+++ b/lilbert/lib/train_eval.py
@@ -9,7 +9,6 @@
 from pytorch_pretrained_bert.modeling import BertConfig
 
 from lib import feature_processors, metrics
-from lib.bert import BertForSequenceClassification
 
 
 def train(model, tokenizer, params,
@@ -20,7 +19,7 @@ def train(model, tokenizer, params,
     random.seed(params['seed'])
     np.random.seed(params['seed'])
     torch.manual_seed(params['seed'])
-
+
     train_steps_per_epoch = int(len(train_examples) / params['train_batch_size'])
     num_train_optimization_steps = train_steps_per_epoch * params['num_train_epochs']
 
@@ -38,32 +37,31 @@ def train(model, tokenizer, params,
         lr=params['learning_rate'],
         warmup=params['warmup_proportion'],
         t_total=num_train_optimization_steps)
-
+
     global_step = 0
     nb_tr_steps = 0
     tr_loss = 0
-
+
     train_features = feature_processors.convert_examples_to_features(
         train_examples,
-        params['label_list'],
-        params['max_seq_length'],
-        tokenizer)
+        tokenizer,
+        params)
     print("***** Running training *****")
-    print("Num examples:", len(train_examples))
+    print("Num examples:", len(train_examples))
     print("Batch size: ", params['train_batch_size'])
     print("Num steps: ", num_train_optimization_steps)
 
     all_input_ids = torch.tensor(
         [f.input_ids for f in train_features],
-        dtype=torch.long)
+        dtype=torch.long)
     all_input_mask = torch.tensor(
         [f.input_mask for f in train_features],
-        dtype=torch.long)
+        dtype=torch.long)
     all_segment_ids = torch.tensor(
         [f.segment_ids for f in train_features],
-        dtype=torch.long)
+        dtype=torch.long)
     all_label_ids = torch.tensor(
         [f.label_id for f in train_features],
-        dtype=torch.long)
+        dtype=torch.long)
     train_data = TensorDataset(all_input_ids,
                                all_input_mask,
                                all_segment_ids,
@@ -117,7 +115,7 @@ def train(model, tokenizer, params,
         'train_loss': tr_loss / nb_tr_steps,
         'train_global_step': global_step,
     }
-
+
     return model, train_result
 
 
@@ -125,21 +123,20 @@ def predict(model, tokenizer, params, valid_examples):
     random.seed(params['seed'])
     np.random.seed(params['seed'])
     torch.manual_seed(params['seed'])
-
+
     eval_features = feature_processors.convert_examples_to_features(
-        valid_examples,
-        params['label_list'],
-        params['max_seq_length'],
-        tokenizer)
+        valid_examples,
+        tokenizer,
+        params)
     all_input_ids = torch.tensor(
         [f.input_ids for f in eval_features],
-        dtype=torch.long)
+        dtype=torch.long)
     all_input_mask = torch.tensor(
         [f.input_mask for f in eval_features],
-        dtype=torch.long)
+        dtype=torch.long)
     all_segment_ids = torch.tensor(
         [f.segment_ids for f in eval_features],
-        dtype=torch.long)
+        dtype=torch.long)
     eval_data = TensorDataset(all_input_ids,
                               all_input_mask,
                               all_segment_ids)
@@ -167,9 +164,9 @@ def evaluate(model, tokenizer, params, valid_examples):
     print("***** Running evaluation *****")
     print("Num examples: ", len(valid_examples))
     print("Batch size: ", params['eval_batch_size'])
-
+
     prob_preds = predict(model, tokenizer, params, valid_examples)
-    true_labels = np.array([int(example.label)
+    true_labels = np.array([int(example.label)
                             for i, example in enumerate(valid_examples)])
     result = {
         'eval_loss': metrics.log_loss(true_labels, prob_preds),
diff --git a/lilbert/notebooks/train_bert_swag.ipynb b/lilbert/notebooks/train_bert_swag.ipynb
new file mode 100644
index 0000000..d7fd280
--- /dev/null
+++ b/lilbert/notebooks/train_bert_swag.ipynb
@@ -0,0 +1,250 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Download the `SWAG` dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!git clone https://github.com/rowanz/swagaf.git\n",
+    "!mv swagaf/data/ ../datasets/SWAG\n",
+    "!rm -fr swagaf"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "sys.path.append('..')\n",
+    "\n",
+    "import numpy as np\n",
+    "import random\n",
+    "import torch\n",
+    "import os\n",
+    "from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
+    "\n",
+    "from lib import data_processors, tasks\n",
+    "from pytorch_pretrained_bert import BertForMultipleChoice\n",
+    "from lib.train_eval import train, evaluate, predict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "\n",
+    "params = {\n",
+    "    'data_dir': 
'../datasets/SWAG',\n", + " 'output_dir': '../output',\n", + " 'cache_dir': '../model_cache',\n", + " 'task_name': 'swag',\n", + " 'bert_model': 'bert-base-uncased',\n", + " 'max_seq_length': 128,\n", + " 'train_batch_size': 12,\n", + " 'eval_batch_size': 8,\n", + " 'learning_rate': 2e-5,\n", + " 'warmup_proportion': 0.1,\n", + " 'num_train_epochs': 1,\n", + " 'seed': 1331,\n", + " 'device': torch.device(\n", + " 'cuda' if torch.cuda.is_available()\n", + " else 'cpu')\n", + "}\n", + "\n", + "random.seed(params['seed'])\n", + "np.random.seed(params['seed'])\n", + "torch.manual_seed(params['seed'])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "processor = tasks.processors[params['task_name']]()\n", + "tokenizer = BertTokenizer.from_pretrained(\n", + " params['bert_model'], do_lower_case=True)\n", + "\n", + "train_examples = processor.get_train_examples(params['data_dir'])\n", + "dev_examples = processor.get_dev_examples(params['data_dir'])\n", + "\n", + "model = BertForMultipleChoice.from_pretrained(\n", + " params['bert_model'],\n", + " cache_dir=params['cache_dir'], num_choices=4).to(params['device'])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "converting examples: 6%|▌ | 4280/73546 [00:05<01:23, 830.33it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mtrain_examples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mvalid_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdev_examples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m checkpoint_files=checkpoint_files)\n\u001b[0m", + "\u001b[0;32m~/projects/lilbert/lilbert/lib/train_eval.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, tokenizer, params, train_examples, valid_examples, checkpoint_files)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mtrain_examples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m params)\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"***** Running training *****\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Num examples:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_examples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/lilbert/lilbert/lib/feature_processors.py\u001b[0m in \u001b[0;36mconvert_examples_to_features\u001b[0;34m(examples, tokenizer, params)\u001b[0m\n\u001b[1;32m 9\u001b[0m params['max_seq_length'], tokenizer)\n\u001b[1;32m 10\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m 
\u001b[0;32mreturn\u001b[0m \u001b[0mconvert_examples_to_features_swag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'max_seq_length'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/lilbert/lilbert/lib/feature_processors.py\u001b[0m in \u001b[0;36mconvert_examples_to_features_swag\u001b[0;34m(examples, max_seq_length, tokenizer)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;31m# able to shrink it according to ending_tokens\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0mcontext_tokens_choice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcontext_tokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0mending_tokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstart_ending_tokens\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mending\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0;31m# Modifies `context_tokens_choice` and `ending_tokens` in\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;31m# place so that the total length is less than the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.6/site-packages/pytorch_pretrained_bert/tokenization.py\u001b[0m in \u001b[0;36mtokenize\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0msplit_tokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mtoken\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasic_tokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0msub_token\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwordpiece_tokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0msplit_tokens\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msub_token\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.6/site-packages/pytorch_pretrained_bert/tokenization.py\u001b[0m in \u001b[0;36mtokenize\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 
178\u001b[0m \u001b[0;31m# characters in the vocabulary because Wikipedia does have some Chinese\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;31m# words in the English Wikipedia.).\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m \u001b[0mtext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tokenize_chinese_chars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 181\u001b[0m \u001b[0morig_tokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwhitespace_tokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0msplit_tokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.6/site-packages/pytorch_pretrained_bert/tokenization.py\u001b[0m in \u001b[0;36m_tokenize_chinese_chars\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchar\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 229\u001b[0;31m \u001b[0mcp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mord\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 230\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_chinese_char\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\" \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "EPOCH_NUM = 1\n", + "\n", + "params['num_train_epochs'] = 1\n", + "checkpoint_files = {\n", + " 'config': 'bert_config.json',\n", + " 'model_weigths': 'model_{}_epoch_{}.pth'.format(\n", + " params['task_name'], EPOCH_NUM)\n", + "}\n", + "\n", + "model, result = train(model, tokenizer, params,\n", + " train_examples,\n", + " valid_examples=dev_examples,\n", + " checkpoint_files=checkpoint_files)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "converting examples: 0%| | 72/20006 [00:00<00:27, 719.02it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "***** Running evaluation *****\n", + "Num examples: 20006\n", + "Batch size: 8\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "converting examples: 100%|██████████| 20006/20006 [00:23<00:00, 869.80it/s]\n", + "Evaluating: 100%|██████████| 2501/2501 [01:02<00:00, 40.08it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'eval_loss': 1.3916393557536797,\n", + " 'eval_accuracy': 0.22358292512246325,\n", + " 'eval_f1_score': 0.21350037939898828}" + ] 
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from sklearn.metrics import f1_score\n",
+    "from lib import metrics\n",
+    "\n",
+    "def f1_score_multiclass(true_labels, prob_preds):\n",
+    "    pred_labels = np.argmax(prob_preds, axis=1)\n",
+    "    return f1_score(true_labels, pred_labels, average='macro')\n",
+    "\n",
+    "print(\"***** Running evaluation *****\")\n",
+    "print(\"Num examples: \", len(dev_examples))\n",
+    "print(\"Batch size: \", params['eval_batch_size'])\n",
+    "\n",
+    "prob_preds = predict(model, tokenizer, params, dev_examples)\n",
+    "true_labels = np.array([int(example.label)\n",
+    "                        for i, example in enumerate(dev_examples)])\n",
+    "result = {\n",
+    "    'eval_loss': metrics.log_loss(true_labels, prob_preds),\n",
+    "    'eval_accuracy': metrics.accuracy(true_labels, prob_preds),\n",
+    "    'eval_f1_score': f1_score_multiclass(true_labels, prob_preds),\n",
+    "}\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}