-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathtrain.py
71 lines (61 loc) · 2.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import tensorflow as tf
import hsr
import sys
import os
import time
def _count_png_images(data_dir):
    """Return how many ``*.png`` files exist anywhere under *data_dir*.

    A missing or unreadable *data_dir* yields 0, because ``os.walk``
    silently yields nothing for it by default.
    """
    total = 0
    for _, _, files in os.walk(data_dir):
        # endswith() rather than split('.')[-1] so that a file literally
        # named "png" (no extension) is not counted as an image.
        total += sum(1 for name in files if name.endswith('.png'))
    return total


def main(argv=None):
    """Train the hsr model on the PNG images found under a directory.

    Usage: ``python train.py training_data/``

    The ``*.png`` files under the directory are counted only to decide when
    an epoch of input has been consumed. The graph is built via the ``hsr``
    module, and training steps run until the coordinator stops or the input
    pipeline raises ``tf.errors.OutOfRangeError``. A checkpoint is written
    to ``<cwd>/checkpoints/model`` at every epoch boundary and once more when
    the input is exhausted.

    Args:
        argv: ``sys.argv``-style argument list; ``argv[1]`` is the
            training-data directory. Defaults to ``sys.argv``.

    Returns:
        Process exit status: 1 for a usage error or when no PNG images are
        found, 0 once training stops.

    Raises:
        Any exception raised inside the training loop. It is handed to the
        coordinator via ``request_stop(e)`` and re-raised by ``coord.join``.
    """
    if argv is None:
        argv = sys.argv
    # % python train.py folder_name
    if len(argv) < 2:
        print('Usage: python %s training_data/' % argv[0])
        return 1
    data_dir = argv[1]

    image_total = _count_png_images(data_dir)
    if image_total == 0:
        # With no images, `images_seen >= image_total` below would be true
        # after every batch, turning each batch into an "epoch" + checkpoint.
        print('No .png images found under %s' % data_dir)
        return 1

    checkpoint_dir = os.path.abspath('checkpoints')  # relative to the CWD
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Create graph
    images, labels = hsr.read_images(data_dir)
    logits = hsr.inference(images)
    loss = hsr.loss(logits, labels)
    train = hsr.training(loss, learning_rate=5e-2)
    accuracy = hsr.evaluation(logits, labels)

    # Run the graph; the context manager guarantees the session is closed.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Local variables are initialized too; presumably hsr's input
        # pipeline keeps an epoch counter there -- TODO confirm in hsr.
        session.run(tf.local_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        coord = tf.train.Coordinator()
        # Start the threads that feed the queue-based input pipeline.
        threads = tf.train.start_queue_runners(sess=session, coord=coord)
        try:
            batch_i = 1
            images_seen = 0  # images consumed so far in the current epoch
            epoch = 1
            start_time = time.time()
            while not coord.should_stop():
                loss_value, acc_value, _ = session.run([loss, accuracy, train])
                # Time since the current epoch began (reset at each epoch
                # boundary), not the duration of this single batch.
                elapsed_time = time.time() - start_time
                print('epoch: %s batch: %s loss: %s accuracy: %s duration: %.3fs'
                      % (epoch, batch_i, loss_value, acc_value, elapsed_time))
                batch_i += 1
                images_seen += hsr.BATCH_SIZE
                if images_seen >= image_total:
                    # Keep the overflow: a batch can straddle two epochs, and
                    # dropping the remainder makes epoch boundaries drift.
                    completed, images_seen = divmod(images_seen, image_total)
                    epoch += completed
                    batch_i = 1
                    saver.save(session, checkpoint_prefix)
                    start_time = time.time()
        except tf.errors.OutOfRangeError:
            # Presumably raised by hsr's input queue once its epoch limit is
            # reached (TODO confirm). Persist the steps run since the last
            # epoch checkpoint so they are not lost.
            saver.save(session, checkpoint_prefix)
            print('')
            print('Done.')
        except Exception as e:
            # Report the failure to the coordinator; coord.join re-raises it.
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
    return 0


if __name__ == '__main__':
    sys.exit(main())