-
Notifications
You must be signed in to change notification settings - Fork 854
/
WGAN.py
256 lines (198 loc) · 10.5 KB
/
WGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class WGAN(object):
    """Wasserstein GAN with critic weight clipping, for 28x28x1 images.

    Usage implied by the methods below: construct the object, call
    ``build_model()`` to create the graph, then call ``train()`` while ``sess``
    is the default session (``train`` uses
    ``tf.global_variables_initializer().run()``).
    """

    model_name = "WGAN"  # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: TensorFlow session used for training, sampling and checkpointing.
            epoch: total number of epochs to train for.
            batch_size: batch size; it is baked into every placeholder and reshape.
            z_dim: dimension of the noise vector fed to the generator.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated sample images.
            log_dir: root directory for TensorBoard summaries.

        Raises:
            NotImplementedError: if ``dataset_name`` is not supported.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name not in ('mnist', 'fashion-mnist'):
            raise NotImplementedError("unsupported dataset: {}".format(dataset_name))

        # image geometry (single-channel 28x28)
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28
        self.z_dim = z_dim  # dimension of noise-vector
        self.c_dim = 1

        # WGAN parameters
        # NOTE: Arjovsky et al. (2017) use RMSProp (lr 5e-5) and 5 critic steps per
        # generator step; this implementation uses Adam, one critic step and a 5x
        # larger generator learning rate (see build_model).
        self.disc_iters = 1  # number of critic iterations per generator step
        self.clip_value = 0.01  # critic weights are clipped to [-clip_value, clip_value]

        # train
        self.learning_rate = 0.0002
        self.beta1 = 0.5

        # test
        self.sample_num = 64  # number of generated images to be saved

        # load mnist; load_mnist (from utils) is assumed to return (images, labels)
        # with images indexable along the first axis -- NOTE(review): confirm.
        self.data_X, self.data_y = load_mnist(self.dataset_name)

        # number of full batches in one epoch (a trailing partial batch is dropped)
        self.num_batches = len(self.data_X) // self.batch_size

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the critic network.

        Architecture as in infoGAN (https://arxiv.org/abs/1606.03657):
        (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S

        Args:
            x: image batch; it is reshaped with ``self.batch_size`` as its first dimension.
            is_training: passed to the batch-norm layers.
            reuse: reuse the variables of the "discriminator" scope.

        Returns:
            (sigmoid(logit), logit, 1024-unit penultimate features). The WGAN
            losses in build_model use only the raw logit.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator network.

        Architecture as in infoGAN (https://arxiv.org/abs/1606.03657):
        FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S

        Args:
            z: noise batch with ``self.batch_size`` rows.
            is_training: passed to the batch-norm layers.
            reuse: reuse the variables of the "generator" scope.

        Returns:
            Images of shape [batch_size, 28, 28, 1], squashed to (0, 1) by a sigmoid.
        """
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))
            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, WGAN losses, optimizers, clip ops, sampler and summaries.

        Critic loss:    mean(D(G(z))) - mean(D(x))
        Generator loss: -mean(D(G(z)))
        """
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # --- graph inputs ---
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # --- losses ---
        # critic on real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # critic on generated images (shares the critic's variables)
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        d_loss_real = - tf.reduce_mean(D_real_logits)
        d_loss_fake = tf.reduce_mean(D_fake_logits)

        self.d_loss = d_loss_real + d_loss_fake
        self.g_loss = - d_loss_fake

        # --- training ops ---
        # Split trainable variables by the scopes opened in discriminator()/generator().
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator/' in var.name]
        g_vars = [var for var in t_vars if 'generator/' in var.name]

        # Any ops registered in UPDATE_OPS (e.g. batch-norm statistics, if `bn`
        # registers them there) run with each optimizer step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # Weight clipping keeps the critic (approximately) Lipschitz. It must be
        # run AFTER d_optim (train() runs it in a separate session call).
        self.clip_D = [p.assign(tf.clip_by_value(p, -self.clip_value, self.clip_value)) for p in d_vars]

        # --- testing ---
        # generator in inference mode, sharing the trained variables
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # --- summaries ---
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def _resume_position(self, checkpoint_counter):
        """Map a checkpoint's step counter to (start_epoch, start_batch_id).

        train() starts ``counter`` at 1 and increments it after every batch, so a
        checkpoint tagged ``c`` means ``c - 1`` batches have completed.
        """
        steps_done = max(checkpoint_counter - 1, 0)
        start_epoch = steps_done // self.num_batches
        start_batch_id = steps_done - start_epoch * self.num_batches
        return start_epoch, start_batch_id

    def _result_path(self, filename):
        """Return the path of ``filename`` inside this model's result directory.

        ``check_folder`` (from utils) is relied on to create the directory and
        return its path -- NOTE(review): confirm against utils.
        """
        return os.path.join(check_folder(os.path.join(self.result_dir, self.model_dir)), filename)

    def _save_train_samples(self, epoch, idx):
        """Render the fixed training noise ``self.sample_z`` and save it as a square grid."""
        samples = self.sess.run(self.fake_images,
                                feed_dict={self.z: self.sample_z})
        tot_num_samples = min(self.sample_num, self.batch_size)
        manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
        manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
        save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                    self._result_path(self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx)))

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if one exists.

        Each step does one critic update followed by weight clipping. A generator
        update happens every ``self.disc_iters`` steps. Samples are saved every
        300 steps and after every epoch. A checkpoint is written after every
        epoch and once more at the end.
        """
        # initialize all variables (requires self.sess to be the default session)
        tf.global_variables_initializer().run()

        # fixed noise so training samples are comparable across steps
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        self.saver = tf.train.Saver()
        self.writer = tf.summary.FileWriter(os.path.join(self.log_dir, self.model_name), self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # Placeholder so the status line works even if the first step after a
        # resume is not a generator step (possible when disc_iters > 1).
        g_loss = float('nan')

        start_time = time.time()
        try:
            for epoch in range(start_epoch, self.epoch):
                for idx in range(start_batch_id, self.num_batches):
                    batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
                    batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                    # Critic step. Clipping runs in its own session call: fetching
                    # d_optim and clip_D together gives no ordering guarantee, so
                    # the clip could read the pre-update weights.
                    _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                           feed_dict={self.inputs: batch_images, self.z: batch_z})
                    self.sess.run(self.clip_D)
                    self.writer.add_summary(summary_str, counter)

                    # generator step
                    if (counter - 1) % self.disc_iters == 0:
                        _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                               feed_dict={self.z: batch_z})
                        self.writer.add_summary(summary_str, counter)

                    counter += 1
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                          % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                    # save training results every 300 steps
                    if np.mod(counter, 300) == 0:
                        self._save_train_samples(epoch, idx)

                # A non-zero start_batch_id only applies to the first epoch after a resume.
                start_batch_id = 0

                self.save(self.checkpoint_dir, counter)
                self.visualize_results(epoch)

            # save model for final step
            self.save(self.checkpoint_dir, counter)
        finally:
            self.writer.close()

    def visualize_results(self, epoch):
        """Save a square grid of samples generated from fresh random noise."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise (not the fixed self.sample_z used during training)
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self._result_path(self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png'))

    @property
    def model_dir(self):
        """Sub-directory name encoding model, dataset, batch size and noise dimension."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint tagged with ``step`` under <checkpoint_dir>/<model_dir>/<model_name>/."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint for this model, if one exists.

        Requires ``self.saver`` (created in train()).

        Returns:
            (True, step) where step is the last run of digits in the checkpoint
            file name (the ``global_step`` given to save()), or (False, 0) when
            no checkpoint is found.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print(" [*] Failed to find a checkpoint")
            return False, 0

        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        # last run of digits in the name, e.g. "WGAN.model-1234" -> 1234
        counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter