-
Notifications
You must be signed in to change notification settings - Fork 0
/
InferenceManager.py
96 lines (66 loc) · 2.97 KB
/
InferenceManager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import configparser
import cv2
import numpy as np
import tensorflow as tf
class InferenceManager(object):
    """Run hand detection and, optionally, gesture classification.

    Model locations are read from the ``[MODELS]`` section of an INI file:
    ``DetectionPath`` (a frozen TensorFlow graph) and ``ClassificationPath``
    (a Keras model, only required when ``detection_only`` is False).

    Each model is loaded into its own ``tf.Graph`` with its own
    ``tf.Session``. Call :meth:`close` (or use the instance in a ``with``
    statement) to release those sessions when the manager is no longer needed.
    """

    def __init__(self, detection_only=False, config_path='config.ini'):
        """Load the models named in the config file.

        Args:
            detection_only: If True, only the detection model is loaded and
                :meth:`classify` raises ``RuntimeError``.
            config_path: Path to the INI file. A relative path is resolved
                against the current working directory.

        Raises:
            FileNotFoundError: If ``config_path`` could not be read.
            KeyError: If the ``[MODELS]`` section or a required key is missing.
        """
        super(InferenceManager, self).__init__()
        self.detection_only = detection_only
        detection_model_path, classification_model_path = self._read_model_paths(
            config_path, detection_only)
        # Frozen hand detection model.
        self._dtc_graph, self._dtc_session = self._load_detection_graph(detection_model_path)
        if not detection_only:
            # Keras gesture classification model.
            self._cls_model, self._cls_graph, self._cls_session = \
                self._load_classification_graph(classification_model_path)

    @staticmethod
    def _read_model_paths(config_path, detection_only):
        """Return ``(detection_path, classification_path)`` from the INI file.

        ``classification_path`` is None when ``detection_only`` is True, so a
        detection-only config does not need a ``ClassificationPath`` entry.
        """
        cfg = configparser.ConfigParser()
        # ConfigParser.read silently skips files it cannot open and returns
        # the list it did read; without this check a missing file would only
        # surface later as an opaque KeyError('MODELS').
        if not cfg.read(config_path):
            raise FileNotFoundError(
                'Could not read model config file: {!r}'.format(config_path))
        models = cfg['MODELS']
        detection_path = models['DetectionPath']
        classification_path = None if detection_only else models['ClassificationPath']
        return detection_path, classification_path

    def _load_detection_graph(self, path):
        """Load a frozen GraphDef from ``path`` into a new graph.

        Returns:
            ``(graph, session)``, where the session is bound to ``graph``.
        """
        graph = tf.Graph()
        with graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(path, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            # name='' keeps the original tensor names, which detect() looks up.
            tf.import_graph_def(od_graph_def, name='')
        session = tf.Session(graph=graph)
        return graph, session

    def _load_classification_graph(self, path):
        """Load a Keras model from ``path`` into its own graph and session.

        Returns:
            ``(model, graph, session)``. classify() re-enters this graph and
            session before calling ``model.predict``.
        """
        # Imported here so detection-only instances never import Keras.
        # NOTE(review): KERAS_BACKEND is presumably only honoured if Keras has
        # not been imported yet in this process — confirm.
        import os
        os.environ['KERAS_BACKEND'] = 'tensorflow'
        import keras
        graph = tf.Graph()
        with graph.as_default():
            # Created while ``graph`` is the default graph, so it is bound to it.
            session = tf.Session()
            with session.as_default():
                model = keras.models.load_model(path)
                graph = tf.get_default_graph()
        return model, graph, session

    def detect(self, image_np):
        """Run the hand detector on a single image.

        Args:
            image_np: One image as a NumPy array; a batch axis of size 1 is
                added here. Presumably (H, W, 3) uint8 as expected by the
                exported ``image_tensor`` input — TODO confirm.

        Returns:
            ``(boxes, scores)`` from the ``detection_boxes`` and
            ``detection_scores`` outputs, with size-1 axes squeezed out.
            NOTE(review): ``np.squeeze`` also drops the detection axis if the
            graph ever returns exactly one detection.
        """
        graph = self._dtc_graph
        image_tensor = graph.get_tensor_by_name('image_tensor:0')
        dtc_boxes = graph.get_tensor_by_name('detection_boxes:0')
        dtc_scores = graph.get_tensor_by_name('detection_scores:0')
        feed_dict = {image_tensor: np.expand_dims(image_np, axis=0)}
        boxes, scores = self._dtc_session.run([dtc_boxes, dtc_scores], feed_dict)
        return np.squeeze(boxes), np.squeeze(scores)

    def classify(self, image):
        """Classify a hand crop with the Keras gesture model.

        The crop is converted to grayscale, mirrored left-right, resized to
        28x28 with area interpolation, scaled to [0, 1] as float64, and fed
        to the model as a batch of one (1, 28, 28, 1).

        Args:
            image: Crop as a NumPy array. ``COLOR_RGB2GRAY`` requires a
                3-channel input, presumably RGB.

        Returns:
            The model's prediction for the single input (row 0 of
            ``model.predict``).

        Raises:
            RuntimeError: If this instance was created with
                ``detection_only=True``.
        """
        if self.detection_only:
            raise RuntimeError('Classification model not loaded for this instance.')
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # NOTE(review): horizontal mirror — presumably matches the handedness
        # of the training data; confirm.
        gray = cv2.flip(gray, 1)
        small = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)
        batch = (small.astype(np.float64) / 255).reshape(1, 28, 28, 1)
        # Keras must run inside the graph/session the model was loaded into.
        with self._cls_graph.as_default(), self._cls_session.as_default():
            return self._cls_model.predict(batch)[0]

    def close(self):
        """Close the TensorFlow sessions held by this instance.

        Safe to call more than once, and on a detection-only instance.
        """
        for attr in ('_dtc_session', '_cls_session'):
            session = getattr(self, attr, None)
            if session is not None:
                session.close()
                setattr(self, attr, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()