-
Notifications
You must be signed in to change notification settings - Fork 419
/
test_detect.py
77 lines (69 loc) · 2.65 KB
/
test_detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os
import time
import numpy as np
from importlib import import_module
import shutil
from utils import *
import sys
from split_combine import SplitComb
import torch
from torch.nn import DataParallel
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
from layers import acc
def _chunk_bounds(n_items, chunk_size):
    """Return boundaries [0, c, 2c, ..., n_items] for slicing n_items into chunks.

    Consecutive pairs of the result are ``[start, stop)`` slice bounds; the
    last chunk may be shorter than `chunk_size`.
    """
    # list() so we can append on Python 3 (range objects are immutable).
    bounds = list(range(0, n_items + 1, chunk_size))
    if bounds[-1] != n_items:
        bounds.append(n_items)
    return bounds


def _forward_chunk(net, patches, coords, isfeat):
    """Run `net` on one chunk of patches with autograd disabled.

    Returns ``(output, feature)`` as numpy arrays.  ``feature`` is None
    unless `isfeat` is true, in which case `net` must return a 2-tuple
    ``(output, feature)``.
    """
    if hasattr(torch, 'no_grad'):
        # PyTorch >= 0.4 ignores `volatile=True` (it only warns), so the
        # autograd graph would still be recorded; no_grad() replaces it.
        with torch.no_grad():
            result = net(patches.cuda(), coords.cuda())
    else:
        # Older PyTorch: volatile Variables skip graph construction.
        result = net(Variable(patches, volatile=True).cuda(),
                     Variable(coords, volatile=True).cuda())
    if isfeat:
        output, feature = result
        return output.data.cpu().numpy(), feature.data.cpu().numpy()
    return result.data.cpu().numpy(), None


def test_detect(data_loader, net, get_pbb, save_dir, config, n_gpu):
    """Run the detector over every case in `data_loader` and save results.

    For each case:
      * the split patches are fed to `net` in chunks of `n_gpu`;
      * the patch outputs are stitched back together with the dataset's
        ``split_comber.combine``;
      * ``get_pbb`` (threshold -3) turns the result into predicted boxes.

    Writes these files to `save_dir`:
      * ``<shortname>_pbb.npy`` (predictions);
      * ``<shortname>_lbb.npy`` (first target entry);
      * ``<shortname>_feature.npy``, only when ``config['output_feature']``
        is truthy.

    Args:
        data_loader: iterable yielding ``(data, target, coord, nzhw)``. Its
            ``dataset`` must expose ``split_comber`` and ``filenames``.
            The loop index is used to look up ``filenames``, so this
            presumably requires batch_size=1 and no shuffling — TODO confirm.
        net: detector module, called as ``net(patches, coords)``.
        get_pbb: callable ``(output, thresh, ismask=True) -> (pbb, mask)``.
        save_dir: directory the ``.npy`` files are written to.
        config: mapping; only the optional ``'output_feature'`` key is read.
        n_gpu: number of patches forwarded per call to `net`.
    """
    start_time = time.time()
    net.eval()
    split_comber = data_loader.dataset.split_comber
    # The config does not change between cases, so read the flag once.
    isfeat = bool('output_feature' in config and config['output_feature'])

    for i_name, (data, target, coord, nzhw) in enumerate(data_loader):
        target = [np.asarray(t, np.float32) for t in target]
        lbb = target[0]
        nzhw = nzhw[0]
        # Take the basename before splitting on '-', so that a '-' in a
        # directory name cannot truncate the case id.
        name = os.path.basename(data_loader.dataset.filenames[i_name]).split('-')[0]
        shortname = name.split('_clean')[0]
        data = data[0][0]
        coord = coord[0][0]
        print(data.size())

        outputlist = []
        featurelist = []
        bounds = _chunk_bounds(len(data), n_gpu)
        for start, stop in zip(bounds[:-1], bounds[1:]):
            output, feature = _forward_chunk(
                net, data[start:stop], coord[start:stop], isfeat)
            outputlist.append(output)
            if isfeat:
                featurelist.append(feature)
        output = np.concatenate(outputlist, 0)
        output = split_comber.combine(output, nzhw=nzhw)
        if isfeat:
            # Move axis 1 to the end and add a trailing singleton axis.
            # Presumably (N, C, D, H, W) -> (N, D, H, W, C, 1), to mirror the
            # layout `combine` receives for `output` — TODO confirm.
            feature = np.concatenate(featurelist, 0).transpose([0, 2, 3, 4, 1])[..., np.newaxis]
            # Stitch with the same patch grid as `output` (previously an
            # undefined `sidelen` was passed here). This assumes the feature
            # map has the output's spatial resolution — TODO confirm.
            feature = split_comber.combine(feature, nzhw=nzhw)[..., 0]

        thresh = -3
        pbb, mask = get_pbb(output, thresh, ismask=True)
        if isfeat:
            feature_selected = feature[mask[0], mask[1], mask[2]]
            np.save(os.path.join(save_dir, shortname + '_feature.npy'), feature_selected)
        #tp,fp,fn,_ = acc(pbb,lbb,0,0.1,0.1)
        #print([len(tp),len(fp),len(fn)])
        print([i_name, shortname])
        np.save(os.path.join(save_dir, shortname + '_pbb.npy'), pbb)
        np.save(os.path.join(save_dir, shortname + '_lbb.npy'), lbb)
    end_time = time.time()
    print('elapsed time is %3.2f seconds' % (end_time - start_time))
    # print('') emits a blank line on both Python 2 and 3 (a bare `print`
    # is a no-op expression on Python 3).
    print('')
    print('')