make_predictions.py
"""
Classify all the images in a holdout set.
"""
import pickle
import sys
import tensorflow as tf
from tqdm import tqdm
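
# Note: this script targets the TensorFlow 1.x graph API (tf.gfile.FastGFile,
# tf.GraphDef, tf.Session). If you are running TensorFlow 2.x, one common shim
# (an assumption, not tested against this pipeline) is:
#
#     import tensorflow.compat.v1 as tf
#     tf.disable_v2_behavior()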

def get_labels():
    """Return a list of our trained labels so we can
    test our training accuracy. The file is in the
    format of one label per line, in the same order
    as the predictions are made. The order can change
    between training runs."""
    with open("./inception/retrained_labels.txt", 'r') as fin:
        labels = [line.rstrip('\n') for line in fin]
    return labels
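
# get_labels() reads ./inception/retrained_labels.txt, which is presumably the
# label file written by TensorFlow's Inception retraining script alongside
# retrained_graph.pb; this script only assumes one class name per line.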

def predict_on_frames(frames, batch):
    """Given a list of frames, predict all their classes."""
    # Unpersist the retrained graph from file.
    with tf.gfile.FastGFile("./inception/retrained_graph.pb", 'rb') as fin:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fin.read())
        _ = tf.import_graph_def(graph_def, name='')

    with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

        frame_predictions = []
        image_path = 'images/' + batch + '/'
        pbar = tqdm(total=len(frames))
        for i, frame in enumerate(frames):
            filename = frame[0]
            label = frame[1]

            # Build the path to this frame's image.
            image = image_path + filename + '.jpg'

            # Read in the raw JPEG bytes.
            image_data = tf.gfile.FastGFile(image, 'rb').read()

            try:
                predictions = sess.run(
                    softmax_tensor,
                    {'DecodeJpeg/contents:0': image_data}
                )
                prediction = predictions[0]
            except KeyboardInterrupt:
                print("You quit with Ctrl+C.")
                sys.exit()
            except Exception:
                print("Error making prediction, continuing.")
                continue

            # Save the probability that it's each of our classes.
            frame_predictions.append([prediction, label])

            if i > 0 and i % 10 == 0:
                pbar.update(10)

        pbar.close()

        return frame_predictions
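
# Each entry returned by predict_on_frames() is [probabilities, true_label],
# where `probabilities` lines up index-for-index with the list from
# get_labels(). get_accuracy() below relies on that ordering.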

def get_accuracy(predictions, labels):
    """After predicting on each batch, check that batch's
    accuracy to make sure things are good to go. This is
    a simple accuracy metric, and so doesn't take confidence
    into account; a confidence-aware metric would be a better
    way to compare changes in the model."""
    correct = 0
    for frame in predictions:
        # Get the highest-confidence class.
        this_prediction = frame[0].tolist()
        this_label = frame[1]

        max_value = max(this_prediction)
        max_index = this_prediction.index(max_value)
        predicted_label = labels[max_index]

        # Now see if it matches.
        if predicted_label == this_label:
            correct += 1

    accuracy = correct / len(predictions)
    return accuracy
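
def get_mean_true_class_prob(predictions, labels):
    """A sketch of the confidence-aware metric hinted at in
    get_accuracy()'s docstring (not part of the original pipeline):
    average the softmax probability assigned to the true class, so
    small improvements show up even when the argmax is unchanged."""
    true_class_probs = []
    for frame in predictions:
        this_prediction = frame[0].tolist()
        this_label = frame[1]
        # Probability the model assigned to the correct class.
        true_class_probs.append(this_prediction[labels.index(this_label)])
    return sum(true_class_probs) / len(true_class_probs)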

def main():
    batches = ['1']
    labels = get_labels()

    for batch in batches:
        print("Doing batch %s" % batch)
        with open('data/labeled-frames-' + batch + '.pkl', 'rb') as fin:
            frames = pickle.load(fin)

        # Predict on this batch and get the accuracy.
        predictions = predict_on_frames(frames, batch)
        accuracy = get_accuracy(predictions, labels)
        print("Batch accuracy: %.5f" % accuracy)

        # Save it.
        with open('data/predicted-frames-' + batch + '.pkl', 'wb') as fout:
            pickle.dump(predictions, fout)

    print("Done.")

if __name__ == '__main__':
    main()
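
# A minimal downstream sketch (assuming batch '1', matching the batches list
# above): load the pickled predictions and inspect one frame's scores.
#
#     import pickle
#     with open('data/predicted-frames-1.pkl', 'rb') as fin:
#         predictions = pickle.load(fin)
#     probs, true_label = predictions[0]
#     print(true_label, max(probs.tolist()))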