-
Notifications
You must be signed in to change notification settings - Fork 9
/
torch_attack.py
272 lines (222 loc) · 8.52 KB
/
torch_attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
# encoding:utf-8
"""Implementation of sample attack."""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms as T
import torch.nn.functional as F
from torch.autograd import Variable as V
# from torch.autograd.gradcheck import zero_gradients
from torch.utils import data
import os
import random
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from torch_nets import (
tf2torch_inception_v3,
tf2torch_inception_v4,
tf2torch_resnet_v2_50,
tf2torch_resnet_v2_101,
tf2torch_resnet_v2_152,
tf2torch_inc_res_v2,
tf2torch_adv_inception_v3,
tf2torch_ens3_adv_inc_v3,
tf2torch_ens4_adv_inc_v3,
tf2torch_ens_adv_inc_res_v2,
)
# Every converted TF->PyTorch network; get_model loads ``<name>.npy`` for each
# from --model_dir, and main() evaluates the adversarial images on all of them.
list_nets = [
    'tf2torch_inception_v3',
    'tf2torch_inception_v4',
    'tf2torch_resnet_v2_50',
    'tf2torch_resnet_v2_101',
    'tf2torch_resnet_v2_152',
    'tf2torch_inc_res_v2',
    'tf2torch_adv_inception_v3',
    'tf2torch_ens3_adv_inc_v3',
    'tf2torch_ens4_adv_inc_v3',
    'tf2torch_ens_adv_inc_res_v2',
]

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0', help='The ID of GPU to use.')
parser.add_argument('--input_csv', type=str, default='dataset/dev_dataset.csv', help='Input csv with images.')
parser.add_argument('--input_dir', type=str, default='dataset/images/', help='Input images.')
parser.add_argument('--output_dir', type=str, default='adv_img_torch/', help='Output directory with adv images.')
parser.add_argument('--model_dir', type=str, default='torch_nets_weight/', help='Model weight directory.')
# Restrict to known networks so a typo fails here instead of after every model has been loaded.
parser.add_argument('--white_model', type=str, default='tf2torch_inception_v3', choices=list_nets,
                    help='Substitution model.')
# Budget in 0-255 pixel units; attack() divides it by 255.
parser.add_argument("--max_epsilon", type=float, default=16.0, help="Maximum size of adversarial perturbation.")
parser.add_argument("--num_iter", type=int, default=10, help="Number of iterations.")
parser.add_argument("--batch_size", type=int, default=10, help="How many images to process at one time.")
# NOTE: not applied by attack(); its MI-FGSM momentum update is disabled.
parser.add_argument("--momentum", type=float, default=1.0, help="Momentum")
opt = parser.parse_args()

# Select the GPU through the environment; this runs at import time, before any
# model is moved to CUDA in main().
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
def seed_torch(seed):
    """Seed every RNG in use (Python, NumPy, PyTorch CPU and CUDA) for repeatable runs.

    cuDNN is also forced into deterministic mode, its autotuner is switched
    off and cuDNN itself is disabled, trading speed for reproducibility.

    Note: ``PYTHONHASHSEED`` only influences Python interpreters started
    after this call; the current process's hash seed is fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False
def mkdir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    A single ``os.makedirs(..., exist_ok=True)`` call replaces the former
    check-then-create sequence, so a directory created concurrently (e.g. by
    another process) between the check and the create no longer raises.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
class Normalize(nn.Module):
    """Per-channel normalization layer computing ``(input - mean) / std``.

    Channel ``i`` (dimension 1 of the input) is shifted by ``mean[i]`` and
    divided by ``std[i]``. Common ImageNet settings:

    * ``'tensorflow'``: ``mean=[0.5, 0.5, 0.5]``, ``std=[0.5, 0.5, 0.5]``
      (maps ``[0, 1]`` to ``[-1, 1]``)
    * ``'torch'``: ``mean=[0.485, 0.456, 0.406]``, ``std=[0.229, 0.224, 0.225]``
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """
        Args:
            mean: per-channel means (any sequence of floats).
            std: per-channel standard deviations.
        """
        super(Normalize, self).__init__()
        # Tuple defaults: an immutable default cannot be shared-and-mutated
        # across instances the way a list default can.
        self.mean = mean
        self.std = std

    def forward(self, input):
        """Normalize ``input`` along dimension 1; gradients flow through.

        Only the first ``input.size(1)`` entries of ``mean``/``std`` are used.
        Assumes a floating-point input.

        Raises:
            ValueError: if ``mean`` or ``std`` has fewer entries than channels.
        """
        channels = input.size(1)
        if channels > len(self.mean) or channels > len(self.std):
            raise ValueError(
                'Normalize got {} channels but mean/std have {}/{} entries'.format(
                    channels, len(self.mean), len(self.std)))
        # Shape (1, C, 1, ..., 1) broadcasts the statistics over every other dimension.
        shape = (1, channels) + (1,) * (input.dim() - 2)
        mean = torch.as_tensor(self.mean[:channels], dtype=input.dtype, device=input.device).view(shape)
        std = torch.as_tensor(self.std[:channels], dtype=input.dtype, device=input.device).view(shape)
        return (input - mean) / std
class ImageNet(data.Dataset):
    """Dataset of images listed in a CSV file.

    Each CSV row provides ``ImageId`` (file stem; ``.png`` is appended) and
    ``TrueLabel``; the image files are read from ``dir``.
    """

    def __init__(self, dir, csv_path, transforms=None):
        """
        Args:
            dir: directory holding the ``<ImageId>.png`` files.
            csv_path: path to the CSV with ``ImageId`` and ``TrueLabel`` columns.
            transforms: optional callable applied to each RGB ``PIL.Image``.
        """
        self.dir = dir
        self.csv = pd.read_csv(csv_path)
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, file_name, true_label)`` for row ``index``.

        ``image`` is the transformed RGB image (or the ``PIL.Image`` itself when
        no transform is set); ``file_name`` is ``ImageId + '.png'``.
        """
        # Positional lookup, so valid indices are exactly range(len(self)).
        row = self.csv.iloc[index]
        image_id = str(row['ImageId']) + '.png'
        true_label = row['TrueLabel']
        img_path = os.path.join(self.dir, image_id)
        # Close the file deterministically; convert() returns a separate image,
        # so pil_img stays usable after the file is closed.
        with Image.open(img_path) as img:
            pil_img = img.convert('RGB')
        sample = self.transforms(pil_img) if self.transforms else pil_img
        return sample, image_id, true_label

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.csv)
def get_model(net_name, model_dir):
    """Build one converted TF->PyTorch classifier behind an input-normalization layer.

    Args:
        net_name: name of a network in ``list_nets``; selects the ``torch_nets``
            module whose ``KitModel`` is instantiated.
        model_dir: directory containing the ``<net_name>.npy`` weight files.

    Returns:
        ``nn.Sequential(Normalize, KitModel)``. Normalize maps ``[0, 1]``
        inputs to ``[-1, 1]``, and the KitModel is put in eval mode on CUDA.
        Networks whose name contains ``'inc'`` are built with
        ``aux_logits=True``.

    Raises:
        ValueError: if ``net_name`` is not a known network. Previously this
            printed a message and called ``exit()``, which reported success
            (status 0).
    """
    nets = {
        'tf2torch_inception_v3': tf2torch_inception_v3,
        'tf2torch_inception_v4': tf2torch_inception_v4,
        'tf2torch_resnet_v2_50': tf2torch_resnet_v2_50,
        'tf2torch_resnet_v2_101': tf2torch_resnet_v2_101,
        'tf2torch_resnet_v2_152': tf2torch_resnet_v2_152,
        'tf2torch_inc_res_v2': tf2torch_inc_res_v2,
        'tf2torch_adv_inception_v3': tf2torch_adv_inception_v3,
        'tf2torch_ens3_adv_inc_v3': tf2torch_ens3_adv_inc_v3,
        'tf2torch_ens4_adv_inc_v3': tf2torch_ens4_adv_inc_v3,
        'tf2torch_ens_adv_inc_res_v2': tf2torch_ens_adv_inc_res_v2,
    }
    net = nets.get(net_name)
    if net is None:
        raise ValueError('Wrong model name: {}!'.format(net_name))
    model_path = os.path.join(model_dir, net_name + '.npy')
    # Inception-family networks are constructed with their auxiliary head enabled.
    extra_kwargs = {'aux_logits': True} if 'inc' in net_name else {}
    return nn.Sequential(
        # All converted TF networks here take inputs normalized to the [-1, 1] interval.
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        net.KitModel(model_path, **extra_kwargs).eval().cuda(),
    )
def get_models(list_nets, model_dir):
    """Load every network named in ``list_nets``.

    Args:
        list_nets: iterable of network names accepted by ``get_model``.
        model_dir: directory holding the converted weight files.

    Returns:
        dict mapping each name to the model returned by ``get_model``,
        in the order the names were given.
    """
    return {name: get_model(name, model_dir) for name in list_nets}
def save_img(images, filenames, output_dir):
    """Write a batch of image tensors to ``output_dir`` as 8-bit RGB files.

    PIL picks the file format from each filename's extension (``.png`` for
    the names produced by ``ImageNet``), so this is not a JPEG writer.

    Args:
        images: indexable batch of ``(C, H, W)`` float tensors, values
            expected in ``[0, 1]`` (anything outside is clamped).
        filenames: output file names, one per image, in the same order.
        output_dir: destination directory; created if missing.
    """
    mkdir(output_dir)
    for i, filename in enumerate(filenames):
        # Scale to [0, 255] and add 0.5 so the truncating uint8 cast rounds to
        # the nearest integer; permute CHW -> HWC, the layout PIL expects.
        ndarr = (
            images[i].mul(255).add_(0.5).clamp_(0, 255)
            .permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        )
        Image.fromarray(ndarr).save(os.path.join(output_dir, filename))
def attack(model, img, label, max_epsilon=None, num_iter=None):
    """Generate adversarial images with the iterative FGSM (I-FGSM) attack.

    Each of ``num_iter`` steps moves the perturbation by
    ``alpha = eps / num_iter`` along the sign of the loss gradient. It then
    projects the perturbation back into the L-inf ball of radius ``eps`` and
    clamps ``img + noise`` to ``[0, 1]``. The momentum (MI-FGSM) variant is
    not applied; ``opt.momentum`` is ignored.

    Args:
        model: classifier returning either a logits tensor or a tuple/list
            whose first two entries are ``(logits, aux_logits)``. In the
            tuple case the two cross-entropy losses are summed.
        img: clean input batch, values expected in ``[0, 1]``.
        label: ground-truth class indices accepted by ``F.cross_entropy``.
        max_epsilon: perturbation budget in 0-255 pixel units; defaults to
            ``opt.max_epsilon``.
        num_iter: number of attack iterations; defaults to ``opt.num_iter``.
            ``num_iter <= 0`` returns an unmodified copy of ``img``.

    Returns:
        The adversarial batch ``img + noise``, detached from the graph.
    """
    if max_epsilon is None:
        max_epsilon = opt.max_epsilon
    if num_iter is None:
        num_iter = opt.num_iter
    eps = max_epsilon / 255.0
    if num_iter <= 0:
        return img.clone()
    alpha = eps / num_iter
    noise = torch.zeros_like(img, requires_grad=True)
    for _ in range(num_iter):
        output = model(img + noise)
        if isinstance(output, (tuple, list)):
            # Inception-style models: attack both the main and the auxiliary head.
            loss = F.cross_entropy(output[0], label) + F.cross_entropy(output[1], label)
        else:
            loss = F.cross_entropy(output, label)
        # Gradient w.r.t. the noise only: unlike loss.backward(), this does not
        # accumulate .grad into every model parameter on each iteration.
        grad, = torch.autograd.grad(loss, noise)
        with torch.no_grad():
            noise = noise + alpha * torch.sign(grad)
            # Stay inside the eps-ball, then keep img + noise a valid image.
            noise = torch.clamp(noise, -eps, eps)
            noise = torch.clamp(img + noise, 0.0, 1.0) - img
        noise.requires_grad_(True)
    return img + noise.detach()
def main():
    """Attack every listed image with the white-box model and report transfer rates.

    Adversarial images are generated by ``attack`` against ``opt.white_model``
    and written to ``opt.output_dir``. Each network in ``list_nets`` then
    classifies them. The printed "attack success rate" for a network is the
    fraction of adversarial images whose predicted class differs from the
    true label.
    """
    transforms = T.Compose([T.ToTensor()])
    # Load inputs
    inputs = ImageNet(opt.input_dir, opt.input_csv, transforms)
    input_num = len(inputs)
    if input_num == 0:
        # Bail out before loading any model; the rate below would divide by zero.
        print('No input images listed in', opt.input_csv)
        return
    data_loader = DataLoader(inputs, batch_size=opt.batch_size, shuffle=False,
                             pin_memory=True, num_workers=8)
    # Create models
    models = get_models(list_nets, opt.model_dir)
    # Per-network count of adversarial images that are misclassified.
    fooled = {net: 0 for net in list_nets}
    for images, filenames, labels in tqdm(data_loader):
        labels = labels.cuda()
        images = images.cuda()
        adv_img = attack(models[opt.white_model], images, labels)
        save_img(adv_img, filenames, opt.output_dir)
        with torch.no_grad():
            for net in list_nets:
                output = models[net](adv_img)
                # Inception-style models return (logits, aux_logits); score the main head.
                logits = output[0] if isinstance(output, (tuple, list)) else output
                # .item() keeps the counter a Python int instead of a growing tensor.
                fooled[net] += (torch.argmax(logits, dim=1) != labels).sum().item()
    for net in list_nets:
        print('{} attack success rate: {:.2%}'.format(net, fooled[net] / input_num))
if __name__ == '__main__':
    # Seed every RNG before any work starts so runs are reproducible.
    seed_torch(seed=0)
    main()