model.py
import tensorflow as tf
# tf.contrib.rnn exposes BasicLSTMCell/MultiRNNCell under TF 1.x, where the old
# tensorflow.python.ops.rnn_cell module path was removed.
from tensorflow.contrib import rnn as rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq
import numpy as np

import constants as c
from utils import unkify


class model:
    def __init__(self, sess, vocab, batch_size, seq_len, cell_size, num_layers, test=False):
        """
        Initializes the model.

        @param sess: The TensorFlow Session.
        @param vocab: A list of all the words on which the model will be trained.
        @param batch_size: The batch size for training.
        @param seq_len: The sequence length. (The number of words in each element of the batch.)
        @param cell_size: The size of the hidden layers in the cells
                          (also the size of the word embeddings).
        @param num_layers: The number of layers in the network.
        @param test: Whether to test or train the model. Default = False.
        """
        self.sess = sess
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.cell_size = cell_size
        self.num_layers = num_layers

        self.build_graph(test)

    def build_graph(self, test):
        """
        Builds the model graph in TensorFlow.
        """
        if test:
            self.batch_size = 1
            self.seq_len = 1

        ##
        # Cells
        ##
        # Build one fresh BasicLSTMCell per layer; reusing a single cell object for
        # every layer makes the layers share weights (and raises an error in newer
        # TF 1.x releases).
        self.cell = rnn_cell.MultiRNNCell(
            [rnn_cell.BasicLSTMCell(self.cell_size) for _ in range(self.num_layers)])
        ##
        # Data
        ##

        # inputs and targets are 2D tensors of shape [batch_size, seq_len]
        self.inputs = tf.placeholder(tf.int32, [self.batch_size, self.seq_len])
        self.targets = tf.placeholder(tf.int32, [self.batch_size, self.seq_len])
        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)
        ##
        # Variables
        ##

        with tf.variable_scope('lstm_vars'):
            self.ws = tf.get_variable('ws', [self.cell_size, self.vocab_size])
            self.bs = tf.get_variable('bs', [self.vocab_size])  # TODO: initializer?

            with tf.device('/cpu:0'):  # put on the CPU to parallelize for faster training
                self.embeddings = tf.get_variable('embeddings', [self.vocab_size, self.cell_size])

                # get embeddings for all input words
                input_embeddings = tf.nn.embedding_lookup(self.embeddings, self.inputs)
        # The split turns this tensor into a seq_len-long list of 3D tensors of shape
        # [batch_size, 1, cell_size]. The squeeze removes the size-1 dimension along
        # axis 1 of each tensor.
        inputs_split = tf.split(input_embeddings, self.seq_len, 1)
        inputs_split = [tf.squeeze(input_, [1]) for input_ in inputs_split]

        def loop(prev, _):
            # At test time, feed the highest-probability word from the previous step back
            # in as the next input (stop_gradient keeps the argmax out of backprop).
            prev = tf.matmul(prev, self.ws) + self.bs
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(self.embeddings, prev_symbol)
        lstm_outputs_split, self.final_state = seq2seq.rnn_decoder(inputs_split,
                                                                   self.initial_state,
                                                                   self.cell,
                                                                   loop_function=loop if test else None,
                                                                   scope='lstm_vars')
        lstm_outputs = tf.reshape(tf.concat(lstm_outputs_split, 1), [-1, self.cell_size])

        logits = tf.matmul(lstm_outputs, self.ws) + self.bs
        self.probs = tf.nn.softmax(logits)
        ##
        # Train
        ##

        # legacy_seq2seq.sequence_loss_by_example expects lists of (logits, targets, weights);
        # the extra vocab_size argument from older TF releases is no longer part of its signature.
        total_loss = seq2seq.sequence_loss_by_example([logits],
                                                      [tf.reshape(self.targets, [-1])],
                                                      [tf.ones([self.batch_size * self.seq_len])])
        self.loss = tf.reduce_sum(total_loss) / self.batch_size / self.seq_len

        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.optimizer = tf.train.AdamOptimizer(learning_rate=c.L_RATE, name='optimizer')
        self.train_op = self.optimizer.minimize(self.loss,
                                                global_step=self.global_step,
                                                name='train_op')

    def generate(self, num_out=200, prime=None, sample=True):
        """
        Generates a sequence of text from the trained model.

        @param num_out: The length of the sequence to generate, in number of words.
        @param prime: The priming sequence for generation. If None, pick a random word from the
                      vocabulary as the prime.
        @param sample: Whether to probabilistically sample the next word, rather than take the
                       word of max probability.
        """
        state = self.sess.run(self.cell.zero_state(1, tf.float32))

        # If no prime is supplied, pick a random word. Otherwise, translate all words in the
        # prime that aren't in the dictionary to '*UNK*'.
        if prime is None:
            prime = np.random.choice(self.vocab)
        else:
            prime = unkify(prime, self.vocab)

        # prime the model state
        for word in prime.split():
            print(word)
            last_word_i = self.vocab.index(word)
            input_i = np.array([[last_word_i]])

            feed_dict = {self.inputs: input_i, self.initial_state: state}
            state = self.sess.run(self.final_state, feed_dict=feed_dict)
        # generate the sequence
        gen_seq = prime
        for i in range(num_out):
            # generate word probabilities
            input_i = np.array([[last_word_i]])  # TODO: use dictionary?
            feed_dict = {self.inputs: input_i, self.initial_state: state}
            probs, state = self.sess.run([self.probs, self.final_state], feed_dict=feed_dict)
            probs = probs[0]

            # select index of the new word
            if sample:
                gen_word_i = np.random.choice(np.arange(len(probs)), p=probs)
            else:
                gen_word_i = np.argmax(probs)

            # append the new word to the generated sequence
            gen_word = self.vocab[gen_word_i]
            gen_seq += ' ' + gen_word

            last_word_i = gen_word_i

        return gen_seq
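

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original training code).
# It assumes constants.py defines L_RATE and that utils.unkify maps words
# missing from the vocabulary to a token such as '*UNK*'; the tiny word list
# below is a stand-in for a real vocabulary, so the output is meaningless
# until trained weights are restored (e.g. with tf.train.Saver).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    toy_vocab = ['*UNK*', 'the', 'cat', 'sat', 'on', 'mat']

    with tf.Session() as sess:
        # test=True builds the graph for single-step decoding (batch_size = seq_len = 1).
        net = model(sess, toy_vocab, batch_size=1, seq_len=1,
                    cell_size=32, num_layers=2, test=True)
        sess.run(tf.global_variables_initializer())

        print(net.generate(num_out=10, prime='the cat', sample=True))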