-
Notifications
You must be signed in to change notification settings - Fork 21
/
predict.py
50 lines (37 loc) · 1.44 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import cv2
import tensorflow as tf
import os
import sys
import numpy as np
from build_model import model_tools
# Directory holding the checkpoint files this script restores from.
MODEL_FOLDER = 'checkpoints'
# Square side length (pixels) every input image is resized to before inference.
IMAGE_SIZE = 100
# Auto-generated tensor names from the saved graph; they must match it exactly.
OUTPUT_TENSOR = "add_4:0"
IMAGE_PLACEHOLDER = "Placeholder:0"
LABEL_PLACEHOLDER = "Placeholder_1:0"
# Sigmoid scores for class 0 strictly above this value are classified as class 0.
THRESHOLD = 0.5


def checkpoint_prefix(folder):
    """Return the checkpoint prefix to pass to ``Saver.restore`` for *folder*.

    The original script restored from ``<folder>/./`` on Linux and had a
    commented-out ``<folder>\\.\\`` variant for Windows. Joining ``'.'`` with
    ``os.sep`` produces exactly the right one for the running platform, so no
    lines need to be swapped by hand.
    """
    return os.path.join(folder, '.' + os.sep)


def classify_score(score, threshold=THRESHOLD):
    """Map the network's sigmoid score for class 0 to a printable label.

    Args:
        score: sigmoid output for class 0; a float in (0, 1).
        threshold: scores strictly above this are treated as class 0.

    Returns:
        "Batman!" for class 0, otherwise "Superman!".
    """
    # A sigmoid output is never exactly 0, so a plain truthiness test would
    # always pick "Batman!"; compare against an explicit threshold instead.
    return "Batman!" if score > threshold else "Superman!"


def load_image(path, size=IMAGE_SIZE):
    """Read *path* with OpenCV and return it as a (1, size, size, 3) batch.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError("Could not read image: {}".format(path))
    img = cv2.resize(img, (size, size))
    return img.reshape(1, size, size, 3)


def main(argv=None):
    """Classify the image named on the command line as Batman or Superman.

    Args:
        argv: argument vector; defaults to ``sys.argv``. ``argv[1]`` is the
            image path.

    Returns:
        Process exit status: 0 on success, 2 if no image path was given.

    Raises:
        ValueError: if the image cannot be read.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.stderr.write("usage: {} IMAGE\n".format(argv[0] if argv else "predict.py"))
        return 2
    img = load_image(argv[1])
    # NOTE(review): the model_tools instance is never used below; the call is
    # kept in case its constructor has side effects (e.g. on the default
    # graph). Confirm against build_model and drop it if it has none.
    model_tools()
    # All-zero labels for the label placeholder, fed exactly as the original
    # script did; presumably the output tensor does not depend on them.
    labels = np.zeros((1, 2))
    with tf.Session() as session:
        # Rebuild the saved graph structure, then load its trained weights.
        saver = tf.train.import_meta_graph(os.path.join(MODEL_FOLDER, '.meta'))
        saver.restore(session, checkpoint_prefix(MODEL_FOLDER))
        graph = tf.get_default_graph()
        # The last layer of the network; evaluating it runs all earlier layers.
        logits = graph.get_tensor_by_name(OUTPUT_TENSOR)
        im_ph = graph.get_tensor_by_name(IMAGE_PLACEHOLDER)
        label_ph = graph.get_tensor_by_name(LABEL_PLACEHOLDER)
        # Sigmoid squashes each logit into (0, 1); it does not binarize them.
        scores = session.run(tf.nn.sigmoid(logits),
                             feed_dict={im_ph: img, label_ph: labels})
    print(classify_score(scores[0][0]))
    return 0


if __name__ == "__main__":
    sys.exit(main())