# -*- coding: utf-8 -*-
"""
@author: alexyang
@contact: alex.yang0326@gmail.com
@file: train_aspect.py
@time: 2018/4/20 22:12
@desc: train aspect classification models
"""
import os
import logging
import pickle

import numpy as np
import tensorflow as tf

import utils
from read_data import read_data_for_aspect
from model import Classifier

flags = tf.app.flags
# common hyper-parameters
flags.DEFINE_integer('embedding_dim', 300, 'word embedding dimension')
flags.DEFINE_integer('n_epoch', 50, 'max epoch to train')
flags.DEFINE_integer('batch_size', 64, 'batch size')
flags.DEFINE_integer('early_stopping_step', 3, 'stop training if the loss has not decreased for 3 consecutive epochs')
flags.DEFINE_float('stddev', 0.01, 'weight initialization stddev')
flags.DEFINE_float('l2_reg', 0.001, 'l2 regularization')
flags.DEFINE_boolean('show', True, 'print train progress')
flags.DEFINE_boolean('embed_trainable', True, 'whether word embeddings are trainable')
flags.DEFINE_string('data', './data/train.csv', 'data set file path')
flags.DEFINE_string('vector_file', './data/embeddings_300_dim.pkl', 'pre-trained word vectors file path')
# hyper-parameter for the aspect classification model
flags.DEFINE_string('classifier_type', 'lstm', 'type of classification model: lstm, cnn')
FLAGS = flags.FLAGS
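
# any flag above can be overridden on the command line; tf.app.run() parses
# argv before calling main()
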
# create the log directory before configuring the file handler, otherwise
# logging.basicConfig raises an error when opening './log/train.log'
if not os.path.exists('./log'):
    os.mkdir('./log')
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                    datefmt='%a, %d %b %Y %H:%M:%S', filename='./log/train.log', filemode='a')


def train_model(data, word_embeddings):
    # 10-fold cross validation
    n_fold = 10
    fold_size = len(data[0]) // n_fold
    loss_list, acc_list = [], []
    for i in range(1, n_fold + 1):
        FLAGS.train_time = i  # expose the current fold index through FLAGS
        train_data, valid_data = utils.split_train_valid(data, i, fold_size)
        # build a fresh graph and session for each fold so variables from
        # previous folds don't collide
        graph = tf.Graph()
        with tf.Session(graph=graph) as sess:
            model = Classifier(FLAGS, sess)
            model.build_model()
            loss, acc = model.run(train_data, valid_data, word_embeddings)
            loss_list.append(loss)
            acc_list.append(acc)
    avg_loss = np.mean(loss_list)
    avg_acc = np.mean(acc_list)
    print('10fold_loss&acc:', avg_loss, avg_acc)
    print('10fold_std_loss&acc:', np.std(loss_list), np.std(acc_list))
    logging.debug('10fold_loss: %s\t10fold_acc: %s', avg_loss, avg_acc)
    logging.debug('10fold_loss_std: %s\t10fold_acc_std: %s', np.std(loss_list), np.std(acc_list))


def main(_):
    pre_trained_vectors = utils.get_gensim_vectors(FLAGS.vector_file)
    # load the dataset, embedding matrix and vocabulary/label mappings
    data, word_embeddings, word2idx, max_context_len, onehot_mapping = read_data_for_aspect(
        FLAGS.data, pre_trained_vectors, FLAGS.embedding_dim)
    # derived hyper-parameters, attached to FLAGS for the model to read
    FLAGS.max_len = max_context_len
    FLAGS.n_word = word_embeddings.shape[0]
    FLAGS.model_path = './save_model'
    FLAGS.model_name = 'm'
    FLAGS.n_class = data[-1].shape[1]
    if not os.path.exists(FLAGS.model_path):
        os.mkdir(FLAGS.model_path)

    print('unique words embedding:', word_embeddings.shape)
    print('max sentence len:', max_context_len)
    print('n_class:', FLAGS.n_class)

    train_model(data, word_embeddings)

    # persist the preprocessing artifacts needed to rebuild model inputs at
    # inference time
    save_data = {'onehot_mapping': onehot_mapping, 'word2idx': word2idx, 'max_context_len': max_context_len}
    with open(os.path.join(FLAGS.model_path, 'save_data'), 'wb') as writer:
        pickle.dump(save_data, writer)


if __name__ == '__main__':
    tf.app.run()
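
# Example invocation (flag values shown are the defaults defined above):
#   python train_aspect.py --data ./data/train.csv \
#       --vector_file ./data/embeddings_300_dim.pkl --classifier_type lstm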