-
Notifications
You must be signed in to change notification settings - Fork 1
/
ae_variational.py
111 lines (82 loc) · 4.67 KB
/
ae_variational.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import tensorflow as tf
import time
from tensorflow.examples.tutorials.mnist import input_data
import ae_tools as tools
class VariationalAutoEncoder:
    """Single-hidden-layer variational autoencoder on the TF1 graph API.

    The encoder maps ``x`` linearly to the mean and log-variance of a diagonal
    Gaussian over an ``n_hidden``-dimensional latent code. A sample is drawn
    with the reparameterization trick and decoded linearly back to
    ``n_input`` dimensions.

    ``__init__`` builds the graph and variables and opens a dedicated
    ``tf.Session``. Call ``close()`` to release that session.
    """

    def __init__(self, n_input, n_hidden, optimizer=None):
        """Build the graph, create the session and initialise all variables.

        Args:
            n_input: dimensionality of each input row.
            n_hidden: dimensionality of the latent code.
            optimizer: optimizer whose ``minimize`` builds the training op.
                Defaults to ``tf.train.AdamOptimizer(learning_rate=0.001)``,
                constructed per instance.
        """
        if optimizer is None:
            # Built here rather than in the signature: a default argument is
            # evaluated once at import time and would be shared by every
            # instance.
            optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.weights = self._initialize_weights()

        # model: encoder -> reparameterized latent sample -> linear decoder
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, self.n_input])
        self.z_mean = tf.add(tf.matmul(self.x, self.weights["w1"]), self.weights["b1"])
        self.z_log_var = tf.add(tf.matmul(self.x, self.weights["log_var_w1"]),
                                self.weights["log_var_b1"])
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I),
        # which keeps the random draw outside the gradient path.
        eps = tf.random_normal(tf.stack([tf.shape(self.x)[0], self.n_hidden]),
                               0, 1, dtype=tf.float32)
        # sigma = sqrt(exp(log_var)) == exp(0.5 * log_var)
        self.hidden = tf.add(self.z_mean, tf.multiply(tf.exp(0.5 * self.z_log_var), eps))
        self.output = tf.add(tf.matmul(self.hidden, self.weights["w2"]), self.weights["b2"])

        # Prior-sampling op, built once so generate() does not grow the graph
        # with a new random op on every call.
        self._n_samples = tf.placeholder(dtype=tf.int32, shape=[])
        self._prior_sample = tf.random_normal(tf.stack([self._n_samples, self.n_hidden]),
                                              dtype=tf.float32)

        # cost: both terms are per-example vectors (reduced over axis 1), so
        # they share one scale before being averaged over the batch.
        self.reconstr_loss = 0.5 * tf.reduce_sum(tf.square(tf.subtract(self.output, self.x)), 1)
        self.latent_loss = -0.5 * tf.reduce_sum(
            1 + self.z_log_var - tf.square(self.z_mean) - tf.exp(self.z_log_var), 1)
        self.cost = tf.reduce_mean(self.reconstr_loss + self.latent_loss)
        # Note: despite the name this holds the training op returned by
        # minimize(), not the optimizer object (name kept for compatibility).
        self.optimizer = optimizer.minimize(self.cost)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def _initialize_weights(self):
        """Create and return the encoder/decoder parameters as a dict.

        Encoder mean and log-variance weights use Xavier initialisation via
        ``tf.get_variable`` with fixed names, so a second instance built in
        the same graph will raise a name conflict. Biases and decoder weights
        start at zero.
        """
        all_weights = dict()
        all_weights["w1"] = tf.get_variable(
            "w1", shape=[self.n_input, self.n_hidden],
            initializer=tf.contrib.layers.xavier_initializer())
        all_weights["b1"] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
        all_weights["log_var_w1"] = tf.get_variable(
            "log_var_w1", shape=[self.n_input, self.n_hidden],
            initializer=tf.contrib.layers.xavier_initializer())
        all_weights["log_var_b1"] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
        # NOTE(review): decoder weights start at zero; they still receive
        # gradient (hidden^T * dL/d_output), but confirm this is intended.
        all_weights["w2"] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32))
        all_weights["b2"] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32))
        return all_weights

    def partial_fit(self, X):
        """Run one optimisation step on batch ``X``; return the batch cost from that run."""
        cost, _ = self.sess.run([self.cost, self.optimizer], feed_dict={self.x: X})
        return cost

    def calculate_total_cost(self, X):
        """Return the mean per-example cost on ``X`` without training."""
        return self.sess.run(self.cost, feed_dict={self.x: X})

    def generate(self, hidden=None, batch_size=10):
        """Decode latent codes to the output space.

        Args:
            hidden: latent codes to decode. If ``None``, ``batch_size`` codes
                are drawn from a standard normal.
            batch_size: number of codes to sample when ``hidden`` is ``None``.
        """
        if hidden is None:
            hidden = self.sess.run(self._prior_sample,
                                   feed_dict={self._n_samples: batch_size})
        # self.hidden is not a placeholder, but TF1 allows feeding any tensor,
        # which bypasses the encoder.
        return self.sess.run(self.output, feed_dict={self.hidden: hidden})

    def output_result(self, X):
        """Encode, sample and decode ``X``; return the reconstruction."""
        return self.sess.run(self.output, feed_dict={self.x: X})

    def close(self):
        """Release the TensorFlow session owned by this model."""
        self.sess.close()
class Runner:
    """Train an autoencoder on MNIST and periodically save reconstructions."""

    def __init__(self, autoencoder):
        """Load MNIST into ``data`` and prepare scaled train/test arrays.

        Args:
            autoencoder: model exposing ``partial_fit``, ``calculate_total_cost``,
                ``output_result``, ``n_input`` and ``n_hidden``.
        """
        self.autoencoder = autoencoder
        self.mnist = input_data.read_data_sets("data", one_hot=True)
        # Scaled copies from the project helper (see ae_tools.min_max_scale);
        # training, sampling and evaluation all use these, not the raw images.
        self.x_train, self.x_test = tools.min_max_scale(self.mnist.train.images,
                                                        self.mnist.test.images)
        self.train_number = self.mnist.train.num_examples

    def train(self, train_epochs=2000, batch_size=64, display_step=1):
        """Train for ``train_epochs`` epochs of randomly drawn mini-batches.

        Every ``display_step`` epochs, a reconstruction image is saved and the
        epoch's average cost is printed. After training, the cost on the
        scaled test set is printed.
        """
        total_batch = int(self.train_number) // batch_size
        for epoch in range(train_epochs):
            avg_cost = 0.
            for _ in range(total_batch):
                batch_xs = tools.get_random_block_from_data(self.x_train, batch_size)
                cost = self.autoencoder.partial_fit(batch_xs)
                # Weight each batch cost by its share of the training set.
                avg_cost += cost / self.train_number * batch_size
            if epoch % display_step == 0:
                self.save_result(file_name="result-{}-{}-{}-{}".format(
                    epoch, self.autoencoder.n_input, self.autoencoder.n_hidden, avg_cost))
                print(time.strftime("%H:%M:%S", time.localtime()),
                      "Epoch:{}".format(epoch + 1), "cost={:.9f}".format(avg_cost))
        # Evaluate on the same scaled test data the model is sampled on.
        print(time.strftime("%H:%M:%S", time.localtime()),
              "Total cost: {}".format(self.autoencoder.calculate_total_cost(self.x_test)))

    def save_result(self, file_name, n_show=10):
        """Save ``n_show`` fixed test images next to their reconstructions.

        The image is written under ``result/ae-variational2/<file_name>.jpg``.
        """
        images = tools.get_random_block_from_data(self.x_test, n_show, fixed=True)
        decode = self.autoencoder.output_result(images)
        # Compare the original images with their reconstructions.
        tools.gaussian_save_result(images, decode, decode,
                                   save_path="result/ae-variational2/{}.jpg".format(file_name))
if __name__ == '__main__':
    # Script entry point: a 784-d input (28x28 MNIST) VAE with a 200-d latent code.
    vae = VariationalAutoEncoder(n_input=784, n_hidden=200)
    Runner(autoencoder=vae).train()