-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
79 lines (59 loc) · 1.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import logging
import os
import random
import re
import sys
from collections import namedtuple
from contextlib import contextmanager
from time import perf_counter, time

import argh
import daiquiri
# Route logging through daiquiri at DEBUG level; `logger` is the module-wide
# logger used by the functions below.
daiquiri.setup(level=logging.DEBUG)
logger = daiquiri.getLogger(__name__)
# Append the parent of the directory containing this file to the import
# path (if not already present), presumably so the project-local `config`
# and `dataset` modules imported next can be found -- confirm repo layout.
# NOTE(review): if __file__ is a bare relative name such as "main.py", both
# dirname() calls return "" and the empty string (current directory) is
# what gets appended.
_PATH_ = os.path.dirname(os.path.dirname(__file__))
if _PATH_ not in sys.path:
    sys.path.append(_PATH_)
from config import FLAGS, HPS
from dataset import get_batch
@contextmanager
def timer(message: str):
    """Log how long the enclosed ``with`` block took.

    On exit, emits ``"<message>: <seconds> seconds"`` (three decimals) at
    INFO level through the module-level ``logger``. The measurement is
    finished in a ``finally`` clause, so the duration is logged even when
    the block raises; the exception itself still propagates unchanged.

    Args:
        message: Label prefixed to the logged duration.
    """
    # perf_counter is monotonic and high-resolution; time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    tick = perf_counter()
    try:
        yield
    finally:
        tock = perf_counter()
        logger.info(f'{message}: {(tock - tick):.3f} seconds')
def train(flags=FLAGS, hps=HPS, *, train_proportion=0.01, test_proportion=0.1):
    """Train the network for ``flags.global_epoch`` global epochs.

    Each global epoch runs ``net.train`` followed by ``net.test`` inside a
    ``timer`` block. It then saves a checkpoint whose name is built from
    the two values ``net.test`` returns (presumably loss and accuracy),
    formatted as ``"<loss>-<acc>"`` with four decimals each.

    Args:
        flags: Run configuration; must provide ``global_epoch``. Also passed
            to ``Net``.
        hps: Hyper-parameters passed to ``Net``.
        train_proportion: Forwarded as ``porportion`` to ``net.train`` each
            epoch (default 0.01, the previously hard-coded value).
        test_proportion: Forwarded as ``porportion`` to ``net.test`` each
            epoch (default 0.1, the previously hard-coded value).
    """
    # Imported lazily so the Network module is only loaded when training runs.
    from Network import Net

    net = Net(flags, hps)
    for epoch in range(flags.global_epoch):
        with timer(f'Global epoch #{epoch}'):
            logger.debug(f'Start global epoch {epoch}')
            # `porportion` (sic) is kept exactly as in the original call;
            # presumably it matches Net's method signatures.
            net.train(porportion=train_proportion)
            loss, acc = net.test(porportion=test_proportion)
            logger.debug(f'Finish global epoch {epoch}')
        # NOTE(review): indentation was lost in the scraped copy; saving once
        # per epoch (after the timed section) is the reconstructed nesting.
        net.save_model(name=f'{loss:.4f}-{acc:.4f}')
    logger.info('All done')
def test(flags=FLAGS, hps=HPS, *,
         load_model_path='./savedmodels/model-0.2513-0.9609.ckpt-1000',
         num_samples=10):
    """Restore a saved model and reconstruct images from random inputs.

    For each of ``num_samples`` iterations, build a ``(10, 16)`` array of
    zeros and fill one randomly chosen row with uniform random values in
    ``[0, 1)``. Pass the array to ``net.reconstruct_img``.

    Args:
        flags: Run configuration passed to ``Net``.
        hps: Hyper-parameters passed to ``Net``.
        load_model_path: Checkpoint passed to ``net.restore_model``
            (defaults to the previously hard-coded path).
        num_samples: Number of random reconstructions (default 10, the
            previously hard-coded count).
    """
    # Imported lazily so these dependencies load only when testing runs.
    from Network import Net
    import numpy as np

    net = Net(flags, hps)
    net.restore_model(load_model_path)
    for _ in range(num_samples):
        # NOTE(review): presumably 10 digit classes x 16-dim vectors, with all
        # rows but one masked to zero -- confirm against Net.reconstruct_img.
        masked_digits = np.zeros((10, 16))
        digit = np.random.randint(10)
        masked_digits[digit] = np.random.random(16)
        net.reconstruct_img(masked_digits)
if __name__ == "__main__":
    # Ensure the output directories exist. exist_ok=True replaces the
    # separate exists()/makedirs() check, which could race with another
    # process creating the same directory in between.
    for _out_dir in ('./train_log', './test_log', './savedmodels'):
        os.makedirs(_out_dir, exist_ok=True)

    # Dispatch on the mode selected in the project config; each entry is
    # called with its default arguments.
    fn = {'train': train, 'test': test}
    action = fn.get(FLAGS.MODE)
    if action is not None:
        action()
    else:
        logger.info('Please choose a mode among "train", "test".')