-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck_test_images.py
60 lines (52 loc) · 1.7 KB
/
check_test_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Run the trained (best) YOLO model on every image listed in the test-image file (2007_test.txt) and display the detections.
"""
import torch
import cv2
import argparse
from utils import detect, draw_boxes
from config import S, B, C
from models.create_model import create_model
def build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser with the threshold, model, weights, device
        and image-list options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-t', '--threshold', help='confidence threshold to filter detected boxes',
        default=0.25, type=float
    )
    parser.add_argument(
        '-m', '--model', default='yolov1_vgg11',
        # Implicit concatenation instead of a backslash continuation, so the
        # help text cannot silently pick up source indentation.
        help='the model to train with, see models/create_model.py for all '
             'available models'
    )
    parser.add_argument(
        '-w', '--weights', default='best.pth',
        help='path to model weight'
    )
    parser.add_argument(
        '-d', '--device',
        default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        help='computing device'
    )
    parser.add_argument(
        '-i', '--image-list', default='2007_test.txt',
        help='text file listing one image path per line'
    )
    return parser


def parse_image_paths(lines):
    """Return the image paths listed in *lines*, one path per line.

    Surrounding whitespace (including '\\n' / '\\r\\n' line endings) is
    stripped and blank lines are skipped. Unlike slicing off the last
    character, this does not truncate a final line that lacks a newline.

    Args:
        lines: iterable of strings, e.g. an open text file.

    Returns:
        list of non-empty path strings, in their original order.
    """
    return [line.strip() for line in lines if line.strip()]


def load_model(model_name, weights_path, device):
    """Build the named model, load its trained weights and put it in eval mode.

    Args:
        model_name: key into the ``create_model`` registry.
        weights_path: path to a saved state_dict.
        device: torch device (or device string) for the model and checkpoint.

    Returns:
        The model on *device*, in eval mode.
    """
    # Local name avoids shadowing the imported `create_model` registry.
    model_fn = create_model[model_name]
    model = model_fn(C, S, B, pretrained=False)
    print('Loading trained YOLO model weights...\n')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint)
    return model.to(device).eval()


def main():
    """Run detection on each listed image and show the result, one at a time."""
    args = vars(build_arg_parser().parse_args())
    device = args['device']
    model = load_model(args['model'], args['weights'], device)

    with open(args['image_list'], 'r') as f:
        test_image_paths = parse_image_paths(f)

    for image_path in test_image_paths:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None for missing/unreadable files.
            print(f"Could not read image: {image_path}, skipping.")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        nms_boxes, _scores = detect(
            model, image, args['threshold'], S=S, device=device
        )
        # NOTE(review): `image` is RGB here, while cv2.imshow expects BGR;
        # unless draw_boxes converts back, colours will appear swapped — confirm.
        result = draw_boxes(image, nms_boxes)
        cv2.imshow('Result', result)
        cv2.waitKey(0)  # Wait for a key press before showing the next image.
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()