-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
72 lines (60 loc) · 2.69 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from argparse import ArgumentParser
import numpy as np
import torch
from torch import optim
from torch.nn import MSELoss
from torchvision.transforms import Resize
from tqdm import trange, tqdm
from Dataloader import kitti_data_loader, vkitti2_data_loader
from Models import AutoEncoder, Loss
def _build_test_loader(args):
    """Return the test-split DataLoader for ``args.dataset``.

    ``'kitti_large'`` and ``'kitti_small'`` share the same KITTI ``'test'``
    split, so they resolve to the same loader.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    if args.dataset in ('kitti_large', 'kitti_small'):
        return kitti_data_loader.get_kitti_loader(
            'Dataloader/kitti', split='test',
            batch_size=args.batch_size, num_workers=args.num_workers)
    if args.dataset == 'vkitti2':
        return vkitti2_data_loader.get_vkitti2_loader(
            'Dataloader/vkitti2', split='test',
            batch_size=args.batch_size, num_workers=args.num_workers)
    raise NotImplementedError(f'Unsupported dataset: {args.dataset!r}')


def train(args):
    """Evaluate an AutoEncoder checkpoint on a dataset's test split.

    Despite its name, this function performs no training. It loads the
    weights stored at ``args.load_ckpt`` and runs the network over the test
    split of ``args.dataset`` with gradients disabled. It then prints the mean
    of the per-batch masked-MSE values.

    Args:
        args: Parsed command-line namespace providing ``dataset``,
            ``batch_size``, ``num_workers``, ``device`` and ``load_ckpt``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of
            ``'kitti_large'``, ``'kitti_small'`` or ``'vkitti2'``.
    """
    # Resolve the dataset first so an unsupported name fails before the
    # checkpoint is loaded.
    test_loader = _build_test_loader(args)

    net = AutoEncoder.AutoEncoder().to(args.device)
    # map_location lets a checkpoint saved on one device (e.g. cuda:0) be
    # loaded when evaluating on another (e.g. cpu).
    state_dict = torch.load(args.load_ckpt, map_location=args.device)
    net.load_state_dict(state_dict)
    net.eval()

    criterion = Loss.MaskedMSE()
    test_loss = []
    with torch.no_grad():
        for sample in test_loader:
            inputs = sample['image'].to(args.device)
            targets = sample['depth'].to(args.device)
            outputs = net(inputs)
            # Targets are scaled by 1/80 before comparison.
            # NOTE(review): presumably this normalizes depth to the network's
            # output range. The third argument is presumably the validity
            # mask; `>= 0` also keeps zero-valued pixels. Confirm both.
            loss = criterion(outputs, targets / 80, targets >= 0)
            test_loss.append(loss.to('cpu').tolist())

    if not test_loss:
        # np.mean([]) would print 'nan' together with a RuntimeWarning.
        print('Test set is empty; masked MSE is undefined')
        return
    print(f'Test masked MSE {np.mean(test_loss)}')
def parse_args(argv=None):
    """Parse command-line options for checkpoint evaluation.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``device``, ``num_workers``,
        ``dataset``, ``model_type``, ``load_ckpt`` and ``batch_size``.
    """
    parser = ArgumentParser(
        description='Evaluate an AutoEncoder checkpoint on a test split.')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='torch device to run on (default: cuda:0)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='DataLoader worker processes (default: 8)')
    # Restrict to the datasets the evaluation code can build, so a typo is
    # reported by argparse with the list of valid names.
    parser.add_argument('--dataset', type=str, required=True,
                        choices=['kitti_large', 'kitti_small', 'vkitti2'],
                        help='dataset whose test split is evaluated')
    parser.add_argument('--model_type', type=str, default='autoencoder',
                        help="model architecture (only 'autoencoder' is built)")
    parser.add_argument('--load_ckpt', type=str, required=True,
                        help='path to the saved state_dict checkpoint')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='evaluation batch size (default: 4)')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the evaluation.
    train(args=parse_args())