-
Notifications
You must be signed in to change notification settings - Fork 2
/
makeLPImages.py
148 lines (123 loc) · 4.58 KB
/
makeLPImages.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
from torch.autograd import Variable
from torch.utils import data
from model import PB_FCN, LabelProp
from dataset import SSDataSet
from transform import Scale, ToLabel, Colorize, ToYUV, labelToPred
from torchvision.transforms import Compose, Normalize, ToTensor
from PIL import Image
import progressbar
from paramSave import saveParams
import argparse
import time
# --- Command-line options ----------------------------------------------------
# Two boolean switches select which trained weights / dataset variant to use.
parser = argparse.ArgumentParser()
parser.add_argument("--finetuned", action="store_true",
                    help="Use finetuned net and dataset")
parser.add_argument("--pruned", action="store_true",
                    help="Use pruned net")
args = parser.parse_args()

fineTune = args.finetuned
pruned = args.pruned

# Suffixes used below to pick the matching checkpoint files.
fineTuneStr = "Finetuned" if fineTune else ""
pruneStr = "Pruned" if pruned else ""
# --- Pre-processing pipelines and evaluation constants ------------------------
# Inputs are downscaled by `scale`, converted to YUV and normalized.
scale = 4

input_transform = Compose([
    Scale(scale, Image.BILINEAR),
    ToYUV(),
    ToTensor(),
    Normalize([.5, .0, .0], [.5, .5, .5]),
])

# Labels use nearest-neighbour scaling so class indices are preserved.
target_transform = Compose([
    Scale(scale, Image.NEAREST),
    ToTensor(),
    ToLabel(),
])

# Label-map size after downscaling; outSize is the per-pixel weight
# (1 / pixel count) used when accumulating the pixel accuracy.
labSize = (480.0 / scale, 640.0 / scale)
outSize = 1.0 / (labSize[0] * labSize[1])

batchSize = 1
root = "./data/"
outDir = "./output/"
# --- Data and model setup -----------------------------------------------------
if fineTune:
    outDir = "./output/FinetuneHorizon/"
    root = "./data/FinetuneHorizon"

valloader = data.DataLoader(SSDataSet(root, split="val", img_transform=input_transform,
                                      label_transform=target_transform),
                            batch_size=batchSize, shuffle=False)

numClass = 5       # number of segmentation classes -- TODO confirm against dataset
kernelSize = 1
numPlanes = 32

model = PB_FCN(numPlanes, numClass, kernelSize, False)
modelLP = LabelProp(numClass, numPlanes)

# Remap CUDA-saved checkpoints onto the CPU unless a GPU is present.
mapLoc = {'cuda:0': 'cpu'}
if torch.cuda.is_available():
    model = model.cuda()
    mapLoc = None

stateDict = torch.load("./pth/bestModelSeg" + fineTuneStr + pruneStr + ".pth", map_location=mapLoc)
model.load_state_dict(stateDict)

stateDict = torch.load("./pth/bestModelLP" + fineTuneStr + pruneStr + ".pth", map_location=mapLoc)
modelLP.load_state_dict(stateDict)

# saveParams expects CPU weights; .cpu() moves the module itself to the CPU.
saveParams("./weightsLP", modelLP.cpu())
# BUGFIX: modelLP was left on the CPU here even when CUDA is available, while
# the evaluation loop feeds it CUDA tensors (device-mismatch crash on GPU
# machines). Move it back onto the GPU to match `model`.
if torch.cuda.is_available():
    modelLP = modelLP.cuda()
# --- Validation loop ----------------------------------------------------------
# Run the segmentation net, refine with the label-propagation net, save
# colorized predictions, and accumulate pixel accuracy / confusion / IoU.
running_acc = 0.0                        # summed per-image pixel accuracy (percent)
imgCnt = 0                               # images processed so far
conf = torch.zeros(numClass, numClass)   # confusion matrix, indexed (pred, label)
IoU = torch.zeros(numClass)              # per-class IoU, accumulated per image
labCnts = torch.zeros(numClass)          # per-class ground-truth pixel counts
model.eval()
# BUGFIX: modelLP was never switched out of training mode, unlike model.
modelLP.eval()
t = 0  # accumulated forward time in seconds (converted to ms after the loop)
bar = progressbar.ProgressBar(0, len(valloader), redirect_stdout=False)
for i, (images, labels) in enumerate(valloader):
    if torch.cuda.is_available():
        images = Variable(images.cuda())
        labels = Variable(labels.cuda())
    else:
        images = Variable(images)
        labels = Variable(labels)
    # NOTE: the original preallocated an `inputs` tensor here; it was dead
    # code (overwritten by torch.cat below before any use) and was removed.
    beg = time.clock()
    # First pass: plain segmentation network.
    pred = model(images)
    _, predClass = torch.max(pred, 1)
    # Second pass: label propagation. Its input is the luma channel twice,
    # a zero channel, and the prediction-formatted label map, concatenated
    # along the channel dimension. (The original wrapped this in a
    # single-iteration `for j in range(1)` loop, removed as a no-op.)
    lab = labelToPred(labels, numClass)
    chY = torch.unsqueeze(torch.unsqueeze(images[0][0], 0), 0)
    inputs = torch.cat([chY, chY, chY - chY, lab], 1)
    pred = modelLP(inputs)
    _, predClass = torch.max(pred, 1)
    t += time.clock() - beg
    running_acc += torch.sum(predClass.data == labels.data) * outSize * 100
    bSize = images.data.size()[0]
    # Save a colorized PNG for every image in the batch.
    for j in range(bSize):
        img = Image.fromarray(Colorize(predClass.data[j]).permute(1, 2, 0).numpy().astype('uint8'))
        img.save(outDir + "%d.png" % (imgCnt + j))
    imgCnt += bSize
    # Per-class binary masks for predictions and ground truth.
    maskPred = torch.zeros(numClass, bSize, int(labSize[0]), int(labSize[1])).long()
    maskLabel = torch.zeros(numClass, bSize, int(labSize[0]), int(labSize[1])).long()
    for currClass in range(numClass):
        maskPred[currClass] = predClass.data == currClass
        maskLabel[currClass] = labels.data == currClass
    for labIdx in range(numClass):
        labCnts[labIdx] += torch.sum(maskLabel[labIdx])
        for predIdx in range(numClass):
            inter = torch.sum(maskPred[predIdx] & maskLabel[labIdx])
            conf[(predIdx, labIdx)] += inter
            # BUGFIX: the original had the same `if labIdx == predIdx:` test
            # nested twice; collapsed to a single check (behavior unchanged).
            if labIdx == predIdx:
                union = torch.sum(maskPred[predIdx] | maskLabel[labIdx])
                if union == 0:
                    # Class absent from both masks: count as a perfect match.
                    IoU[labIdx] += 1
                else:
                    IoU[labIdx] += float(inter) / float(union)
    bar.update(i)
bar.finish()
# --- Final metrics ------------------------------------------------------------
t = t / imgCnt * 1000  # mean forward time per image, in milliseconds
# Normalize each confusion-matrix column by its ground-truth pixel count so
# the entries become percentages of that class.
for labIdx in range(numClass):
    for predIdx in range(numClass):
        conf[(predIdx, labIdx)] /= (labCnts[labIdx] / 100.0)
# Mean class accuracy: average of the normalized confusion-matrix diagonal.
meanClassAcc = 0.0
for j in range(numClass):
    meanClassAcc += conf[(j, j)] / numClass
# IoU was accumulated once per image, so divide by the image count first.
meanIoU = torch.sum(IoU / imgCnt) / numClass * 100
print("Validation Pixel Acc: %.2f Mean Class Acc: %.2f Mean IoU: %.2f" % (running_acc / imgCnt, meanClassAcc, meanIoU))
# BUGFIX: the file mixed function-call print (above) with Python-2-only
# `print conf` / `print t` statements; normalized to the call form, which
# prints identically for a single argument and also works under Python 3.
print(conf)
print(t)