-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
129 lines (101 loc) · 4.61 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from models import YOLOv3
from utils import postprocessing
from config import config
from utils.load_yolov3_weights import load_yolov3_weights
# Command-line interface: one required input path plus an optional class filter.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('input', help='Input image/video.')
parser.add_argument('-c', '--classes', nargs='+',
                    help='Predict the specific class only.')
args = parser.parse_args()
# Build the model graph under a dedicated variable scope so the weight
# loader below can enumerate exactly the model's variables.
with tf.variable_scope('model'):
    print("Constructing computational graph...")
    model = YOLOv3(config)
    print("Done")
print("Loading weights...")
# Collect every variable created under 'model' and build the assign ops that
# copy the pretrained Darknet weights into them; the ops are run later,
# once a session exists.
global_vars = tf.global_variables(scope='model')
assign_ops = load_yolov3_weights(global_vars, config['WEIGHTS_PATH'])
# NOTE: removed a debug leftover that dumped every graph node name via a
# list comprehension used purely for its side effects.
print("Done")
print("=============================================")
print("Loading class names...")
classes = []
colours = {}
# Read one class name per line and assign each class a random BGR colour
# used when drawing its bounding boxes. Use a context manager so the file
# handle is closed (the original leaked it via a bare open()).
with open(config['CLASS_PATH'], 'r') as class_file:
    for i, line in enumerate(class_file.read().splitlines()):
        classes.append(line)
        colours[i] = tuple(int(z) for z in np.random.uniform(0, 255, size=3))
print(classes)
print("Done")
print("=============================================")
print("Running YOLOv3...")
def resize_bbox_to_original(original_image, bbox, image_size=None):
    """Scale a detected bounding box back to the original image's coordinates.

    :param original_image: The original image (H x W x C array).
    :param bbox: The bounding box as (x1, y1, x2, y2) in the resized
        (network-input) coordinate system.
    :param image_size: Side length of the square network input the box was
        predicted on. Defaults to config['IMAGE_SIZE'] for backward
        compatibility.
    :return: The rescaled bounding box as a list of four ints.
    """
    if image_size is None:
        image_size = config['IMAGE_SIZE']
    # (width, height) of the original image; .shape is (rows, cols), so reverse.
    original_size = np.array(original_image.shape[:2][::-1])
    resized_size = np.array([image_size, image_size])
    ratio = original_size / resized_size
    # Scale both corners at once, then flatten back to [x1, y1, x2, y2].
    scaled = np.asarray(bbox).reshape(2, 2) * ratio
    return [int(z) for z in scaled.reshape(-1)]
def label_bboxes(original_image, bbox, class_id, score):
    """Draw a bounding box with a class/score label on the original image.

    The image is modified in place and also returned.

    :param original_image: The original image (modified in place).
    :param bbox: The bounding box in network-input coordinates.
    :param class_id: The class ID of the bounding box.
    :param score: The objectness score in [0, 1].
    :return: The labeled image.
    """
    x1, y1, x2, y2 = resize_bbox_to_original(original_image, bbox)
    label = '{}: {}%'.format(classes[class_id], int(score*100))
    # BUG FIX: draw on the image that was passed in, not the global `frame`;
    # the original only worked because every caller happened to pass `frame`.
    cv2.rectangle(original_image, (x1, y1), (x2, y2), colours[class_id], 2)
    text_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.75, 2)
    w, h = text_size
    # Filled rectangle behind the label text so it stays readable.
    cv2.rectangle(original_image, (x1, y1), (x1 + w, y1 - h), colours[class_id], cv2.FILLED)
    cv2.putText(original_image, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 0), 2, cv2.LINE_AA)
    return original_image
with tf.Session() as sess:
    # Copy the pretrained weights into the graph's variables.
    sess.run(assign_ops)
    if args.input.endswith(('.mp4', '.avi')):
        video = cv2.VideoCapture(args.input)
        video_length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        pbar = tqdm(unit='frame', total=video_length)
        while video.isOpened():
            ret, frame = video.read()
            # `read()` returns False at end of stream or on a read failure.
            # (Original compared identity with `is not True`.)
            if not ret:
                video.release()
                break
            resized_frame = cv2.resize(frame, (config['IMAGE_SIZE'], config['IMAGE_SIZE']))
            detected_bboxes = sess.run(model.outputs, feed_dict={model.inputs: np.expand_dims(resized_frame, axis=0)})
            filtered_bboxes = postprocessing.nms(detected_bboxes, conf_thresh=config['CONF_THRESH'],
                                                 iou_thresh=config['IOU_THRESH'])
            for class_id, v in filtered_bboxes.items():
                # Honour the optional --classes filter.
                if args.classes is None or classes[class_id] in args.classes:
                    for detection in v:
                        label_bboxes(frame, detection['bbox'], class_id, detection['score'])
            pbar.update(1)
            cv2.imshow('Result', frame)
            # Allow the user to quit early with 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                video.release()
                break
        pbar.close()  # flush the progress bar's final state
        print("Done")
    else:
        # Single-image path: run one forward pass and display the result.
        frame = cv2.imread(args.input)
        resized_frame = cv2.resize(frame, (config['IMAGE_SIZE'], config['IMAGE_SIZE']))
        detected_bboxes = sess.run(model.outputs, feed_dict={model.inputs: np.expand_dims(resized_frame, axis=0)})
        filtered_bboxes = postprocessing.nms(detected_bboxes, conf_thresh=config['CONF_THRESH'],
                                             iou_thresh=config['IOU_THRESH'])
        for class_id, v in filtered_bboxes.items():
            if args.classes is None or classes[class_id] in args.classes:
                for detection in v:
                    label_bboxes(frame, detection['bbox'], class_id, detection['score'])
        print("Done")
        cv2.imshow('Result', frame)
        cv2.waitKey(0)
    cv2.destroyAllWindows()  # tear down the display window in both paths