import torch
from imageio import imread, imsave
from skimage.transform import resize
from skimage.util import img_as_float
import numpy as np
from path import Path
import argparse
from tqdm import tqdm

from models import DispNetS
from utils import tensor2array
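
# Example invocation (the checkpoint and dataset paths below are illustrative):
#   python run_inference.py --pretrained /path/to/dispnet_checkpoint.pth.tar \
#       --dataset-dir /path/to/images --output-dir output --output-disp --output-depth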
parser = argparse.ArgumentParser(description='Inference script for DispNet learned with \
Structure from Motion Learner; runs inference on the KITTI and CityScapes datasets',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--output-disp", action='store_true', help="save disparity img")
parser.add_argument("--output-depth", action='store_true', help="save depth img")
parser.add_argument("--pretrained", required=True, type=str, help="pretrained DispNet path")
parser.add_argument("--img-height", default=128, type=int, help="Image height")
parser.add_argument("--img-width", default=416, type=int, help="Image width")
parser.add_argument("--no-resize", action='store_true', help="no resizing is done")
parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file")
parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory")
parser.add_argument("--output-dir", default='output', type=str, help="Output directory")
parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
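
# Inference only: no gradients are needed, so disable autograd for the whole of main().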
@torch.no_grad()
def main():
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must output at least one of disparity or depth!')
        return

    disp_net = DispNetS().to(device)
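    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.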
    weights = torch.load(args.pretrained, map_location=device)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

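    # Either read an explicit list of test files, or glob the dataset
    # directory for every requested image extension.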
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir/file for file in f.read().splitlines()]
    else:
        test_files = sum([list(dataset_dir.walkfiles('*.{}'.format(ext))) for ext in args.img_exts], [])

    print('{} files to test'.format(len(test_files)))

    for file in tqdm(test_files):
        img = img_as_float(imread(file))

        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            img = resize(img, (args.img_height, args.img_width))
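        # HWC -> CHW, add a batch dimension, and rescale from [0, 1] to [-1, 1]
        # before feeding the network.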
        img = np.transpose(img, (2, 0, 1))
        tensor_img = torch.from_numpy(img.astype(np.float32)).unsqueeze(0)
        tensor_img = ((tensor_img - 0.5)/0.5).to(device)

        output = disp_net(tensor_img)[0]  # drop the batch dimension

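        # Rebuild a flat output file name from the file's path relative to the dataset root.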
        file_path, file_ext = file.relpath(args.dataset_dir).splitext()
        file_name = '-'.join(file_path.splitall()[1:])

        if args.output_disp:
            disp = (255*tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
            imsave(output_dir/'{}_disp{}'.format(file_name, file_ext), np.transpose(disp, (1, 2, 0)))
        if args.output_depth:
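            # The network predicts disparity; depth is its reciprocal.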
            depth = 1/output
            depth = (255*tensor2array(depth, max_value=10, colormap='rainbow')).astype(np.uint8)
            imsave(output_dir/'{}_depth{}'.format(file_name, file_ext), np.transpose(depth, (1, 2, 0)))

if __name__ == '__main__':
    main()