-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathRun.py
121 lines (91 loc) · 3.94 KB
/
Run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
This script uses the ground-truth labels' 2D bounding boxes to
crop out the points of interest and feed them into the model so that
it can predict a 3D bounding box for each of the 2D detections.
The script plots the predicted 3D bounding boxes onto the image
and displays them alongside the ground-truth image and its 3D bounding boxes.
This is to help with qualitative assessment.
Images to be evaluated should be placed in Kitti/validation/image_2
FLAGS:
--hide-imgs
    Hides display of the ground truth and predicted bounding boxes
"""
import os
import cv2
import errno
import argparse
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
from lib.DataUtils import *
from lib.Utils import *
from lib import Model, ClassAverages
def _latest_weight_file(weights_dir):
    """Return the name of the last ``.pkl`` file in *weights_dir*, or None.

    Files are ordered by plain lexicographic name sort (same as before), so
    e.g. ``epoch_10.pkl`` sorts before ``epoch_2.pkl`` -- NOTE(review): confirm
    this matches the naming scheme used when the weights are saved.
    A missing directory is treated the same as an empty one.
    """
    try:
        names = os.listdir(weights_dir)
    except FileNotFoundError:
        return None
    candidates = sorted(name for name in names if name.endswith('.pkl'))
    return candidates[-1] if candidates else None


def _load_model(weights_file, device):
    """Build the VGG19-BN based model, load *weights_file* into it and return it on *device* in eval mode."""
    my_vgg = models.vgg19_bn(pretrained=True)
    model = Model.Model(features=my_vgg.features, bins=2)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host
    # (and vice versa), so one call covers both cases.
    checkpoint = torch.load(weights_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Module.to() moves parameters in place; without it the model stays on CPU.
    model.to(device)
    model.eval()
    return model


def _predict_box(model, obj, averages, angle_bins, device):
    """Regress ``(dim, alpha)`` for one cropped object.

    ``dim`` is the network's dimension output plus
    ``averages.get_item(label['Class'])`` (presumably the per-class mean
    dimensions -- the network output is treated as an offset from it).
    ``alpha`` is the angle of the (cos, sin) pair of the most confident
    orientation bin, shifted by that bin's entry in *angle_bins* and by -pi.
    """
    input_tensor = torch.zeros([1, 3, 224, 224])
    # assumes obj.img is broadcastable to (3, 224, 224) -- TODO confirm.
    input_tensor[0, :, :, :] = obj.img
    # Tensor.to()/.cuda() return a NEW tensor; the result must be rebound.
    input_tensor = input_tensor.to(device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        orient, conf, dim = model(input_tensor)

    orient = orient.detach().cpu().numpy()[0, :, :]
    conf = conf.detach().cpu().numpy()[0, :]
    dim = dim.detach().cpu().numpy()[0, :]
    dim += averages.get_item(obj.label['Class'])

    best_bin = np.argmax(conf)
    cos = orient[best_bin, 0]
    sin = orient[best_bin, 1]
    alpha = np.arctan2(sin, cos)
    alpha += angle_bins[best_bin]
    alpha -= np.pi
    return dim, alpha


def main(hide_imgs=None):
    """Predict and plot 3D boxes for every object in Kitti/validation/.

    Args:
        hide_imgs: when true, skip the OpenCV display. When None (the default,
            used by the script entry point), the value is read from the
            module-level ``FLAGS`` parsed in ``__main__``.

    Exits with status 1 if no ``.pkl`` weight file is found under ``weights/``.
    """
    if hide_imgs is None:
        hide_imgs = FLAGS.hide_imgs

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_dir = os.path.abspath(os.path.dirname(__file__))

    weights_path = os.path.join(base_dir, 'weights')
    weight_file = _latest_weight_file(weights_path)
    if weight_file is None:
        print('We could not find any model weight to load, please train the model first!')
        raise SystemExit(1)
    print ('Using previous model weights %s'%weight_file)
    model = _load_model(os.path.join(weights_path, weight_file), device)

    # Load test images from the validation folder (trailing separator kept,
    # as the original path string ended with '/').
    dataset = Dataset(os.path.join(base_dir, 'Kitti', 'validation', ''))
    all_images = dataset.all_objects()
    print ("Length of validation data",len(all_images))
    averages = ClassAverages.ClassAverages()

    print ("Model is commencing predictions.....")
    for key in sorted(all_images):
        data = all_images[key]
        truth_img = data['Image']
        img = np.copy(truth_img)      # receives predicted 3D boxes
        img_gt = np.copy(truth_img)   # receives ground-truth 3D boxes
        cam_to_img = data['Calib']

        for obj in data['Objects']:
            label = obj.label
            dim, alpha = _predict_box(model, obj, averages, dataset.angle_bins, device)
            plot_regressed_3d_bbox_2(img, truth_img, cam_to_img, label['Box_2D'], dim, alpha, obj.theta_ray)
            plot_regressed_3d_bbox_2(img_gt, truth_img, cam_to_img, label['Box_2D'],
                                     label['Dimensions'], label['Alpha'], obj.theta_ray)

        if not hide_imgs:
            numpy_vertical = np.concatenate((truth_img, img_gt, img), axis=0)
            cv2.imshow('2D detection on top, 3D Ground Truth on middle , 3D prediction on bottom', numpy_vertical)
            cv2.waitKey(0)

    if not hide_imgs:
        cv2.destroyAllWindows()
    print ("Finished.")
if __name__ == '__main__':
    # Command-line entry point; FLAGS stays a module-level global for main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--hide-imgs",
        action="store_true",
        help="Hide display of visual results",
    )
    FLAGS = arg_parser.parse_args()
    main()