-
Notifications
You must be signed in to change notification settings - Fork 352
/
train.py
211 lines (159 loc) · 7.91 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import time
import tensorflow as tf
from model import evaluate
from model import srgan
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.losses import MeanAbsoluteError
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.metrics import Mean
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
class Trainer:
    """Single-model trainer with step/PSNR checkpointing and resumable training.

    Training state -- a step counter, the best validation PSNR seen so far, an
    Adam optimizer and the model itself -- is tracked by a
    ``tf.train.Checkpoint``. A ``tf.train.CheckpointManager`` keeps the 3 most
    recent checkpoints in ``checkpoint_dir``, and the latest one (if any) is
    restored on construction so training resumes where it stopped.
    """

    def __init__(self,
                 model,
                 loss,
                 learning_rate,
                 checkpoint_dir='./ckpt/edsr'):
        """Build the checkpointed training state and restore the latest checkpoint.

        Args:
            model: Keras model to train; it is tracked by the checkpoint.
            loss: Callable invoked as ``loss(hr, sr)`` returning a scalar loss.
            learning_rate: Learning rate passed to ``Adam`` (a float or a
                learning-rate schedule object).
            checkpoint_dir: Directory checkpoints are written to / restored from.
        """
        # perf_counter() timestamp marking the start of the current logging interval.
        self.now = None
        self.loss = loss
        self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                              psnr=tf.Variable(-1.0),
                                              optimizer=Adam(learning_rate),
                                              model=model)
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                             directory=checkpoint_dir,
                                                             max_to_keep=3)

        self.restore()

    @property
    def model(self):
        """The model tracked by the checkpoint."""
        return self.checkpoint.model

    def train(self, train_dataset, valid_dataset, steps, evaluate_every=1000, save_best_only=False):
        """Train until the checkpointed step counter reaches ``steps``.

        Every ``evaluate_every`` steps the mean training loss since the last
        report and the validation PSNR are printed, and a checkpoint is saved.
        With ``save_best_only`` the checkpoint is saved only when the PSNR
        improves on the best value recorded in the checkpoint.

        Args:
            train_dataset: Dataset (must support ``.take``) yielding
                ``(lr, hr)`` pairs; it should provide enough elements to reach
                ``steps`` from the restored step count.
            valid_dataset: Dataset passed to ``self.evaluate`` for PSNR.
            steps: Total number of training steps (including restored ones).
            evaluate_every: Evaluation / checkpoint interval in steps.
            save_best_only: If true, skip saving when PSNR did not improve.
        """
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        # Clamp at zero: Dataset.take() treats a negative count (-1 is the
        # documented value) as "take everything", so a checkpoint already past
        # `steps` would otherwise keep training instead of stopping.
        remaining = max(0, steps - int(ckpt.step.numpy()))

        for lr, hr in train_dataset.take(remaining):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % evaluate_every == 0:
                loss_value = loss_mean.result()
                # NOTE(review): newer Keras releases rename this to reset_state();
                # confirm against the installed TensorFlow version.
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value = self.evaluate(valid_dataset)

                duration = time.perf_counter() - self.now
                print(f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f} ({duration:.2f}s)')

                if save_best_only and psnr_value <= ckpt.psnr:
                    self.now = time.perf_counter()
                    # skip saving checkpoint, no PSNR improvement
                    continue

                # Update the tracked variable in place (rather than rebinding the
                # attribute to a tensor) so the best PSNR is written with the
                # checkpoint and restored on resume.
                ckpt.psnr.assign(tf.cast(psnr_value, ckpt.psnr.dtype))
                ckpt_mgr.save()

                self.now = time.perf_counter()

    @tf.function
    def train_step(self, lr, hr):
        """Run one optimization step on a batch and return the batch loss.

        Inputs are cast to float32, the model is run in training mode on
        ``lr``, and ``self.loss(hr, sr)`` is minimized with the checkpointed
        Adam optimizer.
        """
        with tf.GradientTape() as tape:
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)

            sr = self.checkpoint.model(lr, training=True)
            loss_value = self.loss(hr, sr)

        gradients = tape.gradient(loss_value, self.checkpoint.model.trainable_variables)
        self.checkpoint.optimizer.apply_gradients(zip(gradients, self.checkpoint.model.trainable_variables))

        return loss_value

    def evaluate(self, dataset):
        """Return the validation PSNR of the current model on ``dataset``.

        Delegates to the module-level ``evaluate`` imported from ``model``.
        """
        return evaluate(self.checkpoint.model, dataset)

    def restore(self):
        """Restore the latest checkpoint from the manager's directory, if one exists."""
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(f'Model restored from checkpoint at step {self.checkpoint.step.numpy()}.')
class EdsrTrainer(Trainer):
    """Trainer preset with an L1 (mean absolute error) loss, used for EDSR models.

    Defaults: Adam with a piecewise-constant learning rate of 1e-4 up to
    optimizer step 200000 and 5e-5 afterwards; 300000 training steps;
    evaluation every 1000 steps; only PSNR-improving checkpoints are saved.
    """

    def __init__(self,
                 model,
                 checkpoint_dir,
                 learning_rate=PiecewiseConstantDecay(boundaries=[200000], values=[1e-4, 5e-5])):
        l1_loss = MeanAbsoluteError()
        super().__init__(model,
                         loss=l1_loss,
                         learning_rate=learning_rate,
                         checkpoint_dir=checkpoint_dir)

    def train(self, train_dataset, valid_dataset, steps=300000, evaluate_every=1000, save_best_only=True):
        """Run ``Trainer.train`` with the EDSR defaults."""
        super().train(train_dataset,
                      valid_dataset,
                      steps=steps,
                      evaluate_every=evaluate_every,
                      save_best_only=save_best_only)
class WdsrTrainer(Trainer):
    """Trainer preset with an L1 (mean absolute error) loss, used for WDSR models.

    Defaults: Adam with a piecewise-constant learning rate of 1e-3 up to
    optimizer step 200000 and 5e-4 afterwards; 300000 training steps;
    evaluation every 1000 steps; only PSNR-improving checkpoints are saved.
    """

    def __init__(self,
                 model,
                 checkpoint_dir,
                 learning_rate=PiecewiseConstantDecay(boundaries=[200000], values=[1e-3, 5e-4])):
        l1_loss = MeanAbsoluteError()
        super().__init__(model,
                         loss=l1_loss,
                         learning_rate=learning_rate,
                         checkpoint_dir=checkpoint_dir)

    def train(self, train_dataset, valid_dataset, steps=300000, evaluate_every=1000, save_best_only=True):
        """Run ``Trainer.train`` with the WDSR defaults."""
        super().train(train_dataset,
                      valid_dataset,
                      steps=steps,
                      evaluate_every=evaluate_every,
                      save_best_only=save_best_only)
class SrganGeneratorTrainer(Trainer):
    """Trainer preset with a pixel-wise mean squared error loss for the SRGAN generator.

    Defaults: Adam with a constant learning rate of 1e-4; 1000000 training
    steps; evaluation every 1000 steps; only PSNR-improving checkpoints are
    saved.
    """

    def __init__(self,
                 model,
                 checkpoint_dir,
                 learning_rate=1e-4):
        mse_loss = MeanSquaredError()
        super().__init__(model,
                         loss=mse_loss,
                         learning_rate=learning_rate,
                         checkpoint_dir=checkpoint_dir)

    def train(self, train_dataset, valid_dataset, steps=1000000, evaluate_every=1000, save_best_only=True):
        """Run ``Trainer.train`` with the SRGAN generator pre-training defaults."""
        super().train(train_dataset,
                      valid_dataset,
                      steps=steps,
                      evaluate_every=evaluate_every,
                      save_best_only=save_best_only)
class SrganTrainer:
    """Adversarial trainer that updates an SRGAN generator and discriminator together.

    The generator minimizes a perceptual loss (VGG-feature content loss plus a
    down-weighted adversarial term). The discriminator minimizes binary
    cross-entropy on real (``hr``) vs. generated (``sr``) images. Each model has
    its own Adam optimizer. Unlike ``Trainer``, no checkpointing is done here.
    """

    #
    # TODO: model and optimizer checkpoints
    #

    def __init__(self,
                 generator,
                 discriminator,
                 content_loss='VGG54',
                 learning_rate=PiecewiseConstantDecay(boundaries=[100000], values=[1e-4, 1e-5])):
        """Select the VGG feature extractor and create losses and optimizers.

        Args:
            generator: Model mapping low-resolution to super-resolved images.
            discriminator: Model scoring images. Its output is fed to
                ``BinaryCrossentropy(from_logits=False)``, so it presumably
                ends in a sigmoid (TODO confirm).
            content_loss: ``'VGG22'`` or ``'VGG54'``; selects ``srgan.vgg_22()``
                or ``srgan.vgg_54()`` as the feature extractor.
            learning_rate: Learning rate (float or schedule) shared by both
                Adam optimizers.

        Raises:
            ValueError: If ``content_loss`` is neither ``'VGG22'`` nor ``'VGG54'``.
        """
        if content_loss == 'VGG22':
            self.vgg = srgan.vgg_22()
        elif content_loss == 'VGG54':
            self.vgg = srgan.vgg_54()
        else:
            raise ValueError("content_loss must be either 'VGG22' or 'VGG54'")

        self.content_loss = content_loss
        self.generator = generator
        self.discriminator = discriminator
        self.generator_optimizer = Adam(learning_rate=learning_rate)
        self.discriminator_optimizer = Adam(learning_rate=learning_rate)

        self.binary_cross_entropy = BinaryCrossentropy(from_logits=False)
        self.mean_squared_error = MeanSquaredError()

    def train(self, train_dataset, steps=200000, log_every=50):
        """Train for up to ``steps`` batches, printing running mean losses.

        Args:
            train_dataset: Dataset (must support ``.take``) yielding
                ``(lr, hr)`` pairs.
            steps: Maximum number of batches to train on.
            log_every: Positive interval (in steps) at which the mean
                perceptual and discriminator losses are printed and reset.
        """
        pls_metric = Mean()
        dls_metric = Mean()

        for step, (lr, hr) in enumerate(train_dataset.take(steps), start=1):
            pl, dl = self.train_step(lr, hr)
            pls_metric(pl)
            dls_metric(dl)

            if step % log_every == 0:
                # Convert to NumPy before applying a float format spec, as Trainer.train does.
                print(f'{step}/{steps}, perceptual loss = {pls_metric.result().numpy():.4f}, discriminator loss = {dls_metric.result().numpy():.4f}')
                # NOTE(review): newer Keras releases rename this to reset_state();
                # confirm against the installed TensorFlow version.
                pls_metric.reset_states()
                dls_metric.reset_states()

    @tf.function
    def train_step(self, lr, hr):
        """Update generator and discriminator once; return ``(perc_loss, disc_loss)``."""
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)

            sr = self.generator(lr, training=True)

            hr_output = self.discriminator(hr, training=True)
            sr_output = self.discriminator(sr, training=True)

            con_loss = self._content_loss(hr, sr)
            gen_loss = self._generator_loss(sr_output)
            # 0.001 weights the adversarial term relative to the content loss
            # (presumably the SRGAN paper's setting -- TODO confirm).
            perc_loss = con_loss + 0.001 * gen_loss
            disc_loss = self._discriminator_loss(hr_output, sr_output)

        gradients_of_generator = gen_tape.gradient(perc_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return perc_loss, disc_loss

    @tf.function
    def _content_loss(self, hr, sr):
        """MSE between scaled VGG features of ``hr`` and ``sr``.

        Both images go through VGG19 ``preprocess_input``, which presumably
        expects pixel values in [0, 255] (TODO confirm the data range).
        """
        sr = preprocess_input(sr)
        hr = preprocess_input(hr)
        # 1/12.75 rescales the VGG features; presumably the SRGAN paper's
        # factor for matching the MSE scale (TODO confirm).
        sr_features = self.vgg(sr) / 12.75
        hr_features = self.vgg(hr) / 12.75
        return self.mean_squared_error(hr_features, sr_features)

    def _generator_loss(self, sr_out):
        """BCE pushing the discriminator's scores on generated images toward 1 (real)."""
        return self.binary_cross_entropy(tf.ones_like(sr_out), sr_out)

    def _discriminator_loss(self, hr_out, sr_out):
        """BCE with real-image scores toward 1 plus generated-image scores toward 0."""
        hr_loss = self.binary_cross_entropy(tf.ones_like(hr_out), hr_out)
        sr_loss = self.binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
        return hr_loss + sr_loss