train.py
from lib.ModelWrapper import ModelWrapper
from tensorboardX import SummaryWriter
import torch
from torchvision import transforms, datasets
import numpy as np
import random
import sys
import os
args = sys.argv
data_name = args[1] # 'svhn', 'cifar10', 'cifar100'
model_name = args[2] # 'resnet18', 'resnet34', 'vgg16', 'vgg13', 'vgg11'
noise_split = float(args[3])
opt = args[4]
lr = float(args[5])
test_id = int(args[6])
data_root = args[7]
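# Example invocation (the values below are illustrative, not fixed by the script):
#   python train.py cifar10 resnet18 0.2 sgd 0.1 0 ./data
# i.e. dataset, architecture, fraction of training labels to corrupt, optimizer,
# learning rate, run id, and the root directory that holds the datasets.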
# setting
train_batch_size = 128
train_epoch = 250
eval_batch_size = 250
data_root = os.path.join(data_root, data_name)
if data_name == 'cifar10':
    dataset = datasets.CIFAR10
    nb_class = 10
    from archs.cifar10 import vgg, resnet
elif data_name == 'cifar100':
    dataset = datasets.CIFAR100
    nb_class = 100
    from archs.cifar100 import vgg, resnet
elif data_name == 'svhn':
    dataset = datasets.SVHN
    nb_class = 10
    from archs.svhn import vgg, resnet
else:
    raise Exception('No such dataset')
if model_name == 'vgg11':
    model = vgg.vgg11_bn()
elif model_name == 'vgg13':
    model = vgg.vgg13_bn()
elif model_name == 'vgg16':
    model = vgg.vgg16_bn()
elif model_name == 'resnet18':
    model = resnet.resnet18()
elif model_name == 'resnet34':
    model = resnet.resnet34()
else:
    raise Exception("No such model!")
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor()])
eval_transform = transforms.Compose([transforms.ToTensor()])
# load data
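# When noise_split > 0, a random fraction of the training labels is replaced by a
# shuffled copy of itself (symmetric label noise). The corrupted samples are also
# copied into a separate "noise" dataset so their memorization can be evaluated
# each epoch alongside the clean test set.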
if 'cifar' in data_name:
    train_data = dataset(data_root, train=True, transform=train_transform, download=True)
    if noise_split > 0:
        train_targets = np.array(train_data.targets)
        data_size = len(train_targets)
        random_index = random.sample(range(data_size), int(data_size*noise_split))
        random_part = train_targets[random_index]
        np.random.shuffle(random_part)
        train_targets[random_index] = random_part
        train_data.targets = train_targets.tolist()
        noise_data = dataset(data_root, train=False, transform=eval_transform, download=True)
        noise_data.targets = random_part.tolist()
        noise_data.data = train_data.data[random_index]
    test_data = dataset(data_root, train=False, transform=eval_transform)
    var_data = dataset(data_root, train=True, transform=eval_transform, download=True)
elif 'svhn' in data_name:
    train_data = dataset(data_root, split='train', transform=train_transform, download=True)
    if noise_split > 0:
        train_targets = np.array(train_data.labels)
        data_size = len(train_targets)
        random_index = random.sample(range(data_size), int(data_size * noise_split))
        random_part = train_targets[random_index]
        np.random.shuffle(random_part)
        train_targets[random_index] = random_part
        train_data.labels = train_targets.tolist()
        noise_data = dataset(data_root, split='test', transform=eval_transform, download=True)
        noise_data.labels = random_part.tolist()
        noise_data.data = train_data.data[random_index]
    test_data = dataset(data_root, split='test', transform=eval_transform)
    var_data = dataset(data_root, split='train', transform=eval_transform, download=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True, num_workers=0,
                                           drop_last=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=eval_batch_size, shuffle=False, num_workers=0,
                                          drop_last=False)
var_loader = torch.utils.data.DataLoader(var_data, batch_size=train_batch_size, shuffle=False, num_workers=0,
                                         drop_last=False)
if noise_split > 0:
    noise_loader = torch.utils.data.DataLoader(noise_data, batch_size=eval_batch_size, shuffle=True, num_workers=0,
                                               drop_last=False)
# build model
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
if opt == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
elif opt == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
else:
    raise Exception('No such optimizer')
wrapper = ModelWrapper(model, optimizer, criterion, device)
# train the model
save_path = os.path.join('runs', 'noise_{}_opt_{}_lr_{}'.format(noise_split, opt, lr),
                         data_name, "{}".format(model_name), "{}".format(test_id))
if not os.path.exists(save_path):
    os.makedirs(save_path)
writer = SummaryWriter(log_dir=os.path.join(save_path, "log"), flush_secs=30)
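# Scalars are written under runs/noise_<split>_opt_<opt>_lr_<lr>/<dataset>/<model>/<test_id>/log;
# they can be inspected with, e.g., `tensorboard --logdir runs` (command given for illustration).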
wrapper.train()
for id_epoch in range(train_epoch):
    # train loop
    train_loss = 0
    train_acc = 0
    train_size = 0
    for id_batch, (inputs, targets) in enumerate(train_loader):
        loss, acc, correct, _, _ = wrapper.train_on_batch_with_gradients_recorded(inputs, targets)
        train_loss += loss
        train_acc += correct
        train_size += len(targets)
        print("epoch:{}/{}, batch:{}/{}, loss={}, acc={}".
              format(id_epoch + 1, train_epoch, id_batch + 1, len(train_loader), loss, acc))

    # record loss and acc
    train_loss /= id_batch + 1
    train_acc /= train_size
    writer.add_scalar("train acc", train_acc, id_epoch+1)
    writer.add_scalar("train loss", train_loss, id_epoch+1)

    # record output var
    wrapper.eval()
    optimization_var = wrapper.get_optimization_var(var_loader)
    writer.add_scalar("optimization var", optimization_var, id_epoch+1)

    # eval
    test_loss, test_acc = wrapper.eval_all(test_loader)
    print("epoch:{}/{}, batch:{}/{}, testing...".format(id_epoch + 1, train_epoch, id_batch + 1, len(train_loader)))
    print("clean: loss={}, acc={}".format(test_loss, test_acc))
    writer.add_scalar("test acc", test_acc, id_epoch+1)
    writer.add_scalar("test loss", test_loss, id_epoch+1)
    if noise_split > 0:
        noise_loss, noise_acc = wrapper.eval_all(noise_loader)
        print("noise: loss={}, acc={}".format(noise_loss, noise_acc))
        writer.add_scalar("noise acc", noise_acc, id_epoch+1)
        writer.add_scalar("noise loss", noise_loss, id_epoch+1)
    wrapper.train()

writer.close()