# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
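"""
Train an attribute-conditioned autoencoder with adversarial discriminators
(on the latent code, on image patches, and via a classifier).

A minimal sketch of an invocation, assuming the dataset expected by
src.loader.load_images is in place and an evaluation classifier has been
trained and saved beforehand (both paths below are placeholders):

    python train.py --name my_experiment --attr "Race.5" \
        --eval_clf models/eval_clf.pth
"""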
import os
import argparse
import torch
from src.loader import load_images, DataSampler
from src.utils import initialize_exp, bool_flag, attr_flag, check_attr
from src.model import AutoEncoder, LatentDiscriminator, PatchDiscriminator, Classifier
from src.training import Trainer
from src.evaluation import Evaluator

# parse parameters
parser = argparse.ArgumentParser(description='Image autoencoder')
parser.add_argument("--name", type=str, default="default",
                    help="Experiment name")
parser.add_argument("--img_sz", type=int, default=256,
                    help="Image size (images must be square)")
parser.add_argument("--img_fm", type=int, default=3,
                    help="Number of feature maps (1 for grayscale, 3 for RGB)")
parser.add_argument("--attr", type=attr_flag, default="Race.5",
                    help="Attributes to classify")
parser.add_argument("--instance_norm", type=bool_flag, default=False,
                    help="Use instance normalization instead of batch normalization")
parser.add_argument("--init_fm", type=int, default=32,
                    help="Number of initial filters in the encoder")
parser.add_argument("--max_fm", type=int, default=512,
                    help="Maximum number of filters in the autoencoder")
parser.add_argument("--n_layers", type=int, default=6,
                    help="Number of layers in the encoder / decoder")
parser.add_argument("--n_skip", type=int, default=0,
                    help="Number of skip connections")
parser.add_argument("--deconv_method", type=str, default="convtranspose",
                    help="Deconvolution method (convtranspose / upsampling / pixelshuffle)")
parser.add_argument("--hid_dim", type=int, default=512,
                    help="Last hidden layer dimension for discriminator / classifier")
parser.add_argument("--dec_dropout", type=float, default=0.,
                    help="Dropout in the decoder")
parser.add_argument("--lat_dis_dropout", type=float, default=0.3,
                    help="Dropout in the latent discriminator")
parser.add_argument("--n_lat_dis", type=int, default=1,
                    help="Number of latent discriminator training steps")
parser.add_argument("--n_ptc_dis", type=int, default=0,
                    help="Number of patch discriminator training steps")
parser.add_argument("--n_clf_dis", type=int, default=0,
                    help="Number of classifier discriminator training steps")
parser.add_argument("--smooth_label", type=float, default=0.2,
                    help="Smooth label for the patch discriminator")
parser.add_argument("--lambda_ae", type=float, default=1,
                    help="Autoencoder loss coefficient")
parser.add_argument("--lambda_lat_dis", type=float, default=0.0001,
                    help="Latent discriminator loss feedback coefficient")
parser.add_argument("--lambda_ptc_dis", type=float, default=0,
                    help="Patch discriminator loss feedback coefficient")
parser.add_argument("--lambda_clf_dis", type=float, default=0,
                    help="Classifier discriminator loss feedback coefficient")
parser.add_argument("--lambda_schedule", type=float, default=500000,
                    help="Progressively increase discriminators' lambdas (0 to disable)")
parser.add_argument("--v_flip", type=bool_flag, default=False,
                    help="Random vertical flip for data augmentation")
parser.add_argument("--h_flip", type=bool_flag, default=True,
                    help="Random horizontal flip for data augmentation")
parser.add_argument("--batch_size", type=int, default=32,
                    help="Batch size")
parser.add_argument("--ae_optimizer", type=str, default="adam,lr=0.0002",
                    help="Autoencoder optimizer (SGD / RMSprop / Adam, etc.)")
parser.add_argument("--dis_optimizer", type=str, default="adam,lr=0.0002",
                    help="Discriminator optimizer (SGD / RMSprop / Adam, etc.)")
parser.add_argument("--clip_grad_norm", type=float, default=5,
                    help="Clip gradient norms (0 to disable)")
parser.add_argument("--n_epochs", type=int, default=1000,
                    help="Total number of epochs")
parser.add_argument("--epoch_size", type=int, default=50000,
                    help="Number of samples per epoch")
parser.add_argument("--ae_reload", type=str, default="",
                    help="Reload a pretrained encoder")
parser.add_argument("--lat_dis_reload", type=str, default="",
                    help="Reload a pretrained latent discriminator")
parser.add_argument("--ptc_dis_reload", type=str, default="",
                    help="Reload a pretrained patch discriminator")
parser.add_argument("--clf_dis_reload", type=str, default="",
                    help="Reload a pretrained classifier discriminator")
parser.add_argument("--eval_clf", type=str, default="",
                    help="Load an external classifier for evaluation")
parser.add_argument("--debug", type=bool_flag, default=False,
                    help="Debug mode (only load a subset of the whole dataset)")
params = parser.parse_args()

# check parameters
check_attr(params)
assert len(params.name.strip()) > 0
assert params.n_skip <= params.n_layers - 1
assert params.deconv_method in ['convtranspose', 'upsampling', 'pixelshuffle']
assert 0 <= params.smooth_label < 0.5
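# pretrained reload paths are optional, but must point to existing files
# when given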
assert not params.ae_reload or os.path.isfile(params.ae_reload)
assert not params.lat_dis_reload or os.path.isfile(params.lat_dis_reload)
assert not params.ptc_dis_reload or os.path.isfile(params.ptc_dis_reload)
assert not params.clf_dis_reload or os.path.isfile(params.clf_dis_reload)
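# the evaluation classifier, in contrast, is required (the Evaluator
# below relies on it)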
assert os.path.isfile(params.eval_clf)
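# a non-zero feedback coefficient only makes sense if the corresponding
# discriminator is actually trained (n_*_dis > 0)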
assert params.lambda_lat_dis == 0 or params.n_lat_dis > 0
assert params.lambda_ptc_dis == 0 or params.n_ptc_dis > 0
assert params.lambda_clf_dis == 0 or params.n_clf_dis > 0

# initialize experiment / load dataset
logger = initialize_exp(params)
data, attributes = load_images(params)
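# index 0 holds the training split, index 1 the validation split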
train_data = DataSampler(data[0], attributes[0], params)
valid_data = DataSampler(data[1], attributes[1], params)

# build the model
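# each discriminator is only instantiated when it will actually be
# trained (its step count n_*_dis is non-zero); unused ones stay None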
ae = AutoEncoder(params).cuda()
lat_dis = LatentDiscriminator(params).cuda() if params.n_lat_dis else None
ptc_dis = PatchDiscriminator(params).cuda() if params.n_ptc_dis else None
clf_dis = Classifier(params).cuda() if params.n_clf_dis else None
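# the evaluation classifier was saved as a full serialized module, so
# torch.load restores it directly; .eval() switches it to inference mode
# (disables dropout and freezes batch-norm statistics)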
eval_clf = torch.load(params.eval_clf).cuda().eval()

# trainer / evaluator
trainer = Trainer(ae, lat_dis, ptc_dis, clf_dis, train_data, params)
evaluator = Evaluator(ae, lat_dis, ptc_dis, clf_dis, eval_clf, valid_data, params)
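
# for every batch, the enabled discriminators each take n_*_dis gradient
# steps before the autoencoder takes a single step, so the discriminators
# always see the current state of the encoder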
for n_epoch in range(params.n_epochs):

    logger.info('Starting epoch %i...' % n_epoch)

    for n_iter in range(0, params.epoch_size, params.batch_size):

        # latent discriminator training
        for _ in range(params.n_lat_dis):
            trainer.lat_dis_step()

        # patch discriminator training
        for _ in range(params.n_ptc_dis):
            trainer.ptc_dis_step()

        # classifier discriminator training
        for _ in range(params.n_clf_dis):
            trainer.clf_dis_step()

        # autoencoder training
        trainer.autoencoder_step()

        # print training statistics
        trainer.step(n_iter)

    # run all evaluations / save best or periodic model
    to_log = evaluator.evaluate(n_epoch)
    trainer.save_best_periodic(to_log)
    logger.info('End of epoch %i.\n' % n_epoch)