-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtfdetector.py
67 lines (51 loc) · 2.15 KB
/
tfdetector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
TensorFlow object detector.
"""
# pylint: disable=C0103,C0301,R0902
import time
from typing import List
import tensorflow as tf
import entities as e
from objdetector import ObjectDetector
class TensorFlowDetector:
    """
    Object detector backed by a serialized TensorFlow GraphDef model.

    The model file is parsed as a ``tf.compat.v1.GraphDef``, imported into a
    private graph and executed through a ``tf.compat.v1.Session``. The graph
    must expose the tensors ``image_tensor:0``, ``detection_boxes:0``,
    ``detection_scores:0``, ``detection_classes:0`` and ``num_detections:0``.
    Call ``stop()`` to release the session.
    """

    def __init__(self, config, logger):
        """
        Load the detection model and open a session on it.

        :param config: mapping providing ``'model'`` (path of the serialized
            GraphDef, opened with ``tf.io.gfile.GFile``) and ``'threshold'``
            (score threshold passed to
            ``ObjectDetector.getDetectedObjectsCollection``).
        :param logger: object with a ``debug(msg)`` method, used to list the
            model's placeholder nodes.
        :raises KeyError: if the graph does not contain one of the expected
            tensors (raised by ``get_tensor_by_name``); also if ``config`` is a
            dict missing ``'model'`` or ``'threshold'``.
        """
        self._logger = logger
        self._threshold = config['threshold']
        model_path = config['model']

        with tf.io.gfile.GFile(model_path, 'rb') as model_file:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(model_file.read())

        # List the graph's input placeholders to help diagnose model mismatches.
        # Op types are compared exactly; a substring test against 'Placeholder'
        # would also accept unrelated op types that happen to be substrings.
        self._logger.debug('TF model placeholders:')
        placeholders = [f'{node.name} => {node.op}'
                        for node in graph_def.node if node.op == 'Placeholder']
        for description in placeholders:
            self._logger.debug(description)
        self._logger.debug('TF model placeholders end.')

        with tf.compat.v1.Graph().as_default() as graph:
            # name='' is required: import_graph_def otherwise prefixes every
            # imported node with 'import/', and the tensor lookups by plain
            # name below would fail.
            tf.compat.v1.import_graph_def(graph_def, name='')

        # Resolve all tensors before opening the session, so a model lacking
        # one of them fails here without leaving an unclosed session behind.
        self._image_tensor = graph.get_tensor_by_name('image_tensor:0')
        self._boxes = graph.get_tensor_by_name('detection_boxes:0')
        self._scores = graph.get_tensor_by_name('detection_scores:0')
        self._classes = graph.get_tensor_by_name('detection_classes:0')
        self._num_detections = graph.get_tensor_by_name('num_detections:0')
        self._session = tf.compat.v1.Session(graph=graph)

    def stop(self):
        """
        Release the TensorFlow session opened in ``__init__``.
        """
        self._session.close()

    def detectObjects(self, img) -> List[e.DetectedObject]:
        """
        Implementation of the detector interface: run the model on one image.

        :param img: 3-D image array; its shape is unpacked as
            ``(height, width, _)`` and it is reshaped to a batch of one with
            exactly 3 channels. NOTE(review): dtype/colour order expected by
            the model is not visible here — confirm against the model.
        :return: the value of ``ObjectDetector.getDetectedObjectsCollection``
            for the (class, score, box) triples of the single batch element,
            given the image height/width and the configured threshold.
        """
        height, width, _ = img.shape
        batch = img.reshape(1, height, width, 3)

        # perf_counter is monotonic and intended for measuring durations,
        # unlike time.time(), which can jump with system clock adjustments.
        started = time.perf_counter()
        # NOTE(review): num_detections is fetched but unused; any padded
        # entries are left for the threshold filter downstream to discard.
        boxes, scores, classes, _ = self._session.run(
            [self._boxes, self._scores, self._classes, self._num_detections],
            feed_dict={self._image_tensor: batch})
        elapsed = time.perf_counter() - started
        self._logger.debug(f'TF model inferring time: {elapsed}')

        # Index 0 selects the only batch element fed above.
        detections = zip(classes[0], scores[0], boxes[0])
        return ObjectDetector.getDetectedObjectsCollection(
            detections, height, width, self._threshold)