-
Notifications
You must be signed in to change notification settings - Fork 13
/
test_net.py
55 lines (40 loc) · 1.64 KB
/
test_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import os
from my_transform import demension_reduce
from my_image_folder import ImageFolder
from torch.autograd import Variable
from my_transform import transform
from define_net import Net
from torch.autograd import Variable
if __name__ == '__main__':
path_ = os.path.abspath('.')
net = Net()
net.load_state_dict(torch.load(path_+'/net_relu.pth')) # your net
testset = ImageFolder(path_+'/test_set/',transform) # your test set
f = open(path_+'/result_relu.txt','w') # where to write answer
tys = {} # map typhoon to its max wind
tys_time = {} # map typhoon-time to wind
for i in range(0,testset.__len__()):
image, actual = testset.__getitem__(i)
image = image.expand(1,image.size(0),image.size(1),image.size(2)) # a batch with 1 sample
name = testset.__getitemName__(i)
output = net(Variable(image))
wind = output.data[0][0] # output is a 1*1 tensor
name = name.split('_')
tid = name[0]
if tys.has_key(tid):
if tys[tid] < wind:
tys[tid] = wind
else :
tys[tid] = wind
tid_time = name[0]+'_'+name[1]+'_'+name[2]+'_'+name[3]
tys_time[tid_time] = wind
if i % 100 == 99 :
print 'have processed ',i+1,' samples.'
tys = sorted(tys.iteritems(),key=lambda asd:asd[1],reverse=True)
for ty in tys:
print ty # show the sort of typhoons' wind
tys_time = sorted(tys_time.iteritems(),key=lambda asd:asd[0],reverse=False)
for ty in tys_time:
f.write(str(ty)+'\n') # record all result by time
f.close()