-
Notifications
You must be signed in to change notification settings - Fork 108
/
Copy pathtest_voc_video.py
96 lines (86 loc) · 4.46 KB
/
test_voc_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
@author: Thang Nguyen <nhthang1009@gmail.com>
"""
import argparse
import pickle
import cv2
import numpy as np
from src.utils import *
from src.yolo_net import Yolo
# The 20 Pascal VOC object categories. The position of a name in this list is
# used to pick its drawing color from the palette in ``test`` (and the list is
# passed to post_processing). Presumably it matches the class order the model
# was trained with -- TODO confirm against the training script.
CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
def get_args():
    """Build the command-line interface for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the network input size, detection
        thresholds, checkpoint location and input/output video paths.
        ``--test_set``, ``--year`` and ``--data_path`` are accepted but are
        not read by ``test`` in this script.
    """
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
    add = parser.add_argument

    # Frames are resized to image_size x image_size before inference.
    add("--image_size", type=int, default=448,
        help="The common width and height for all images")
    # Thresholds forwarded to post_processing.
    add("--conf_threshold", type=float, default=0.35)
    add("--nms_threshold", type=float, default=0.5)

    # Dataset options; parsed but unused by the video test.
    add("--test_set", type=str, default="test",
        help="For both VOC2007 and 2012, you could choose 3 different datasets: train, trainval and val. Additionally, for VOC2007, you could also pick the dataset name test")
    add("--year", type=str, default="2007", help="The year of dataset (2007 or 2012)")
    add("--data_path", type=str, default="data/VOCdevkit", help="the root folder of dataset")

    # "model" = whole pickled module, "params" = state dict for Yolo(20).
    add("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
    add("--pre_trained_model_path", type=str, default="trained_models/whole_model_trained_yolo_voc")

    # Video read from --input, annotated copy written to --output.
    add("--input", type=str, default="test_videos/input.mp4")
    add("--output", type=str, default="test_videos/output_voc.mp4")

    return parser.parse_args()
def _load_model(opt, device):
    """Load the YOLO network selected by ``opt`` and move it onto ``device``.

    ``opt.pre_trained_model_type == "model"`` loads a whole pickled module from
    ``opt.pre_trained_model_path``; any other value builds ``Yolo(20)`` and
    loads the file into it as a state dict.
    """
    # Without CUDA, remap every storage onto the CPU so GPU-saved checkpoints
    # still load; with CUDA, keep torch.load's default placement.
    if device.type == "cuda":
        map_location = None
    else:
        map_location = lambda storage, loc: storage
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects whole pickled modules; the "model" path may need
    # weights_only=False on such versions -- confirm for your torch version.
    if opt.pre_trained_model_type == "model":
        model = torch.load(opt.pre_trained_model_path, map_location=map_location)
    else:
        model = Yolo(20)
        model.load_state_dict(torch.load(opt.pre_trained_model_path, map_location=map_location))
    # A freshly built Yolo (or a CPU-saved module) lives on the CPU; move it so
    # it matches the device the input frames are sent to.
    return model.to(device)


def _draw_predictions(canvas, predictions, colors, width_ratio, height_ratio, width, height):
    """Draw each prediction's box and ``"<class> : <score>"`` label onto ``canvas`` in place.

    Each prediction is indexed as ``(x, y, w, h, score, class_name)`` in the
    resized network-input frame. Dividing by the ratios maps it back onto the
    original ``width`` x ``height`` frame, and the result is clamped to the
    frame bounds.
    """
    for pred in predictions:
        xmin = int(max(pred[0] / width_ratio, 0))
        ymin = int(max(pred[1] / height_ratio, 0))
        xmax = int(min((pred[0] + pred[2]) / width_ratio, width))
        ymax = int(min((pred[1] + pred[3]) / height_ratio, height))
        color = colors[CLASSES.index(pred[5])]
        label = pred[5] + ' : %.2f' % pred[4]
        cv2.rectangle(canvas, (xmin, ymin), (xmax, ymax), color, 2)
        text_w, text_h = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
        # Filled background behind the label so the white text stays readable.
        cv2.rectangle(canvas, (xmin, ymin), (xmin + text_w + 3, ymin + text_h + 4), color, -1)
        cv2.putText(canvas, label, (xmin, ymin + text_h + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                    (255, 255, 255), 1)


def test(opt):
    """Run YOLO detection on every frame of ``opt.input`` and write an annotated copy.

    Args:
        opt: namespace as returned by ``get_args()``. Reads
            pre_trained_model_type, pre_trained_model_path, image_size,
            conf_threshold, nms_threshold, input and output.

    Raises:
        OSError: if the input video cannot be opened.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(opt, device)
    model.eval()

    # Per-class drawing colors, indexed by position in CLASSES. Unpickling is
    # only safe because this is a local file; never point it at untrusted data.
    with open("src/pallete", "rb") as palette_file:
        colors = pickle.load(palette_file)

    cap = cv2.VideoCapture(opt.input)
    if not cap.isOpened():
        raise OSError("Cannot open input video: {}".format(opt.input))
    # Keep the source frame rate as a float: truncating e.g. 29.97 to 29 makes
    # the output play back slower than the input.
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), fps, frame_size)
    try:
        while cap.isOpened():
            flag, frame = cap.read()
            if not flag:
                break
            height, width = frame.shape[:2]
            # cvtColor returns a new array, so the BGR frame itself can be
            # annotated and written out without an extra copy.
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image, (opt.image_size, opt.image_size))
            # HWC image -> CHW float32, then a leading batch axis of size 1.
            batch = np.transpose(np.array(image, dtype=np.float32), (2, 0, 1))[None, :, :, :]
            data = torch.from_numpy(np.ascontiguousarray(batch)).to(device)
            width_ratio = float(opt.image_size) / width
            height_ratio = float(opt.image_size) / height
            with torch.no_grad():
                logits = model(data)
                predictions = post_processing(logits, opt.image_size, CLASSES, model.anchors,
                                              opt.conf_threshold, opt.nms_threshold)
            if len(predictions) != 0:
                # One input image per batch, so only the first entry is used.
                _draw_predictions(frame, predictions[0], colors, width_ratio, height_ratio,
                                  width, height)
            out.write(frame)
    finally:
        cap.release()
        out.release()
if __name__ == "__main__":
    # Script entry point: parse the CLI, then annotate the input video.
    test(opt=get_args())