forked from ZZUTK/Face-Aging-CAAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
44 lines (34 loc) · 1.44 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow as tf
from FaceAging import FaceAging
flags = tf.app.flags

# Command-line flags.
# Arguments are given positionally as (name, default, help) so they bind
# correctly under both flag APIs:
#   - older tf.app.flags: (flag_name, default_value, docstring)
#   - newer absl-backed tf.app.flags: (name, default, help)
# The newer API accepts the old keyword names only through a deprecation shim.
flags.DEFINE_integer('epoch', 50, 'number of epochs')
flags.DEFINE_boolean('is_train', True, 'training mode')
flags.DEFINE_string('dataset', 'UTKFace', 'dataset name')
flags.DEFINE_string('savedir', 'save', 'dir for saving training results')
# NOTE: the default is the *string* 'None', not the None object.
flags.DEFINE_string('testdir', 'None', 'dir for testing images')

# Parsed flag values; read as attributes, e.g. FLAGS.epoch.
FLAGS = flags.FLAGS
def main(_):
    """Build a FaceAging model and run it in training or testing mode.

    Args:
        _: Remaining command-line arguments passed in by ``tf.app.run``;
            unused.

    Raises:
        ValueError: In testing mode (``--is_train=False``) when ``--testdir``
            is unset or still at its ``'None'`` string default. Without this
            check there is no real directory to read test images from.
    """
    import pprint

    # Print the current settings.
    # absl-backed TensorFlow flags provide flag_values_dict() (name -> value);
    # there, FLAGS.__flags maps names to Flag objects rather than values.
    # Older tf.app.flags has no such method, so fall back to __flags.
    flag_values_dict = getattr(FLAGS, 'flag_values_dict', None)
    if callable(flag_values_dict):
        settings = flag_values_dict()
    else:
        settings = FLAGS.__flags
    pprint.pprint(settings)

    # Fail fast instead of handing the literal pattern 'None/*jpg' to the model.
    if not FLAGS.is_train and FLAGS.testdir in (None, '', 'None'):
        raise ValueError('--testdir must be set when --is_train=False')

    config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory use on demand instead of reserving it all.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as session:
        model = FaceAging(
            session,  # TensorFlow session
            is_training=FLAGS.is_train,  # flag for training or testing mode
            save_dir=FLAGS.savedir,  # path to save checkpoints, samples, and summary
            dataset_name=FLAGS.dataset,  # name of the dataset in the folder ./data
        )
        # Single-argument print(...) behaves the same on Python 2 and 3.
        if FLAGS.is_train:
            print('\n\tTraining Mode')
            model.train(
                num_epochs=FLAGS.epoch,  # number of epochs
            )
        else:
            print('\n\tTesting Mode')
            model.custom_test(
                testing_samples_dir=FLAGS.testdir + '/*jpg',
            )
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then calls main with the
    # leftover argv. main is passed explicitly instead of being looked up
    # on the __main__ module.
    tf.app.run(main=main)