-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
70 lines (59 loc) · 2.55 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Script to evaluate a test-time-augmentation ensemble built from a single WideResNet model.
Each prediction is made multiple times on randomly modified images (when --augment is set). Then, the average of all
predictions is taken as the final prediction.
"""
import argparse
import os

import matplotlib.pyplot as plt; plt.style.use("fivethirtyeight")
import torch
from tqdm import tqdm, trange

from data import get_kaggle_testloader
from models import WideResNet22
# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path", type=str, required=True, help="Path to the Kaggle test folder.")
parser.add_argument("-s", "--state_path", type=str, required=True,
                    help="Path to the saved WideResNet22 state dict.")
parser.add_argument("--augment", action="store_true", help="Whether to use data augmentation.")
parser.add_argument("-e", "--ensemble", type=int, default=1, help="Number of ensembles to use")
parser.add_argument("--save_path", type=str, default="submission.csv", help="Path in which to save submission.")
parser.add_argument("-bs", "--batch_size", type=int, default=1000, help="Test batch size.")
args = parser.parse_args()
# At least one pass is required: with zero passes no submission file is
# written, and the final report would reference an undefined file name.
if args.ensemble < 1:
    parser.error("--ensemble must be at least 1")
if args.batch_size < 1:
    parser.error("--batch_size must be at least 1")
print(args)
# Label names; column k of the network output is mapped to classes[k] when
# the submission is written.
classes = ('airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print("Reading kaggle data into memory...")
testloader = get_kaggle_testloader(args.path, augment=args.augment, batch_size=args.batch_size)
# Fall back to the CPU so the script also runs on machines without a GPU;
# on GPU machines this is identical to the previous hard-coded "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reading net
net = WideResNet22()
# map_location lets a checkpoint saved on another device load onto `device`.
net.load_state_dict(torch.load(args.state_path, map_location=device))
net.to(device).eval()
print("Making predictions...")
# Making predictions: run `args.ensemble` passes over the test set. With
# --augment, each pass presumably sees differently augmented images (this is
# decided by get_kaggle_testloader — TODO confirm). The outputs of all passes
# so far are averaged, and a submission file is written after every pass, so
# the file for pass e reflects an ensemble of e + 1 passes.
pred_sum = None  # running sum of per-pass outputs, rows ordered by image id
for e in trange(args.ensemble):
    pass_preds = []
    pass_ids = []
    with torch.no_grad():
        for x, idx in tqdm(testloader, leave=False):
            pass_preds.append(net(x.to(device)))
            pass_ids.append(idx)
    ids = torch.cat(pass_ids)
    # Put rows in image-id order so row k is the same image in every pass;
    # averaging by position would silently mix images if the loader's order
    # differed between passes.
    order = torch.argsort(ids)
    ids = ids[order]
    y_pred = torch.cat(pass_preds)[order.to(device)]
    # A running sum avoids keeping, and re-stacking, every pass's outputs.
    pred_sum = y_pred if pred_sum is None else pred_sum + y_pred
    mean_pred = pred_sum / (e + 1)
    final_prediction = [classes[label] for label in torch.argmax(mean_pred, 1).tolist()]
    # Generating submission file. The pass number is prefixed to the file
    # name only, so a --save_path containing directories stays valid.
    save_dir, save_name = os.path.split(args.save_path)
    filename = os.path.join(save_dir, str(e) + save_name)
    print("Generating submission file {} ...".format(filename))
    # Rows are already in ascending id order (see the argsort above).
    with open(filename, "w") as f:
        f.write("id,label\n")
        f.writelines("{},{}\n".format(image_id, label)
                     for image_id, label in zip(ids.tolist(), final_prediction))
    print("Done")
print("Done generating predictions")
print("Final prediction is in the file", filename)