# generate_maxsat_instances.py -- 173 lines (150 loc), 7.08 KB
import argparse
import copy
import math
from pathlib import Path
from typing import Counter, Dict, List, Optional, Sequence, Tuple
import chainer
import numpy as np
import bnn
import datasets
import encoder
import visualize
# Command-line interface.
#
# --format is mandatory: it has no default, and every later stage branches on
# it.  Marking it required lets argparse reject a missing value up front with a
# usage message, instead of the script failing later with args.format == None.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset', type=str, choices=['mnist', 'mnist_back_image', 'mnist_rot'],
    default='mnist', help='dataset name',
)
parser.add_argument('--model', type=str, default=None, help='model file (*.npz)')
parser.add_argument(
    '-o', '--output-dir', type=str, default="instances/maxsat",
    help='output directory',
)
parser.add_argument(
    '--format', type=str, choices=["wbo", "wcnf"], required=True,
    help='file format',
)
parser.add_argument(
    '--norm', type=str, choices=['0', '1', '2', 'inf'], nargs='*',
    default=['0', '1', '2', 'inf'], help='norm to minimize',
)
parser.add_argument(
    '--card', type=str, choices=["sequential", "parallel", "totalizer"],
    default="parallel", help='encoding of cardinality constraints',
)
parser.add_argument(
    '--target', type=str, default="adversarial",
    choices=['adversarial', 'truelabel'], help='target label',
)
parser.add_argument(
    '--instance-no', type=int, nargs='*', default=[],
    help='specify instance number',
)
parser.add_argument(
    '--instances-per-class', type=int, default=None,
    help='number of instances to generate per class',
)
parser.add_argument(
    '--debug-sat', action='store_true', help='produce CNF or OPB for debug',
)
parser.add_argument(
    '--ratio', type=float, nargs='*', default=[1.0],
    help='restrict search space to most salient pixels',
)
args = parser.parse_args()
# Output directory: created on demand, including missing parents.
result_dir = Path(args.output_dir)
result_dir.mkdir(parents=True, exist_ok=True)

train, test = datasets.get_dataset(args.dataset)

# Fall back to the per-dataset weights file when --model is not given.
weights_filename = f"models/{args.dataset}.npz" if args.model is None else args.model

# Layer widths handed to bnn.BNN; the first entry is 28 * 28 and the last is 10.
neurons = [28 * 28, 200, 100, 100, 100, 10]
model = bnn.BNN(neurons, stochastic_activation=True)
chainer.serializers.load_npz(weights_filename, model)

# Only the first 1000 entries of the first underlying dataset array are used;
# every per-instance array below therefore has at most 1000 rows.
orig_image_scaled = test._datasets[0][:1000]
# Rescale by 255 and round to obtain 8-bit pixel values.
orig_image = np.round(orig_image_scaled * 255).astype(np.uint8)

# Run the network with train mode and backprop both switched off.
with chainer.using_config("train", False):
    with chainer.using_config("enable_backprop", False):
        # Boolean per-pixel result of the input batch-norm followed by bnn.bin.
        orig_image_bin = bnn.bin(model.input_bn(orig_image_scaled)).array > 0
        orig_logits = model(orig_image_scaled).array
# Predicted class of each image = index of its largest logit.
predicated_label = np.argmax(orig_logits, axis=1)
# Build the base encoder once; the per-instance loop takes copy.copy() of it
# and adds instance-specific clauses on top.
if args.format == "wbo":
    enc_base = encoder.BNNEncoder(cnf=False)
elif args.format == "wcnf":
    # Only the CNF encoder takes the --card cardinality-encoding choice.
    enc_base = encoder.BNNEncoder(cnf=True, counter=args.card)
else:
    # An f-string keeps this error usable even when args.format is None;
    # "..." + None would raise TypeError instead of the intended RuntimeError.
    raise RuntimeError(f"unknown ext: {args.format}")

# 784 input variables and 10 output variables
# (presumably one per pixel of a 28x28 image and one per class -- confirm
# against encoder.BNNEncoder).
inputs = enc_base.new_vars(784)
outputs = enc_base.new_vars(10)
enc_base.encode_bin_input(model, inputs, outputs)

# Number of instances generated so far, keyed by true label.
counter = Counter[int]()
# Generate optimisation instances for each selected, correctly classified
# entry of `test`.
for instance_no, (x, true_label) in enumerate(test):
    # Optional filters: explicit instance numbers and a per-class quota.
    if len(args.instance_no) > 0 and instance_no not in args.instance_no:
        continue
    if args.instances_per_class is not None and counter[true_label] >= args.instances_per_class:
        continue

    # The per-instance arrays (predicated_label, orig_image, orig_image_bin)
    # may cover only a prefix of `test`.  Stop cleanly instead of raising
    # IndexError once an instance without precomputed data would be processed.
    if instance_no >= len(predicated_label):
        print(f"instance={instance_no} has no precomputed prediction; stopping")
        break

    print(f"dataset={args.dataset}; instance={instance_no}; true_label={true_label} predicted_label={predicated_label[instance_no]}")
    # Misclassified inputs are skipped and do not count towards the quota.
    if predicated_label[instance_no] != true_label:
        continue
    counter[true_label] += 1

    # Save the original image next to the instances (never overwritten).
    img = visualize.to_image(args.dataset, orig_image[instance_no])
    fname = result_dir / f"bnn_{args.dataset}_{instance_no}_label{true_label}.png"
    if not fname.exists():
        img.save(fname)

    # The saliency map needs backprop enabled, unlike plain inference.
    with chainer.using_config("train", False), chainer.using_config("enable_backprop", True):
        saliency_map = model.saliency_map(x, true_label)

    # Instance-specific encoder: force the true-label output on ("truelabel")
    # or off ("adversarial").
    enc = copy.copy(enc_base)
    if args.target == "truelabel":
        enc.add_clause([outputs[true_label]])
    else:
        enc.add_clause([-outputs[true_label]])

    # Parameters of the input batch-normalisation layer.
    input_bn = model.input_bn
    mu = input_bn.avg_mean
    sigma = np.sqrt(input_bn.avg_var + input_bn.eps)
    gamma = input_bn.gamma.array
    beta = input_bn.beta.array

    # mod[j] = (literal that holds for the original image,
    #           signed pixel change needed to flip it, or None if no value in
    #           0..255 can flip it).
    numerically_unstable = False
    mod: List[Tuple[encoder.Lit, Optional[int]]] = []
    for j, pixel in enumerate(orig_image[instance_no]):
        # Use a plain Python int: arithmetic on np.uint8 scalars depends on
        # NumPy's promotion rules and can wrap around for negative differences.
        pixel = int(pixel)
        # C_frac = 255 * (- beta[j] * sigma[j] / gamma[j] + mu[j])
        C_frac = (- beta[j] * sigma[j] / gamma[j] + mu[j]) / np.float32(1 / 255.0)
        if gamma[j] >= 0:
            # x ≥ ⌈255 (- βσ/γ + μ)⌉ = C
            C = int(math.ceil(C_frac))
            # The integer threshold must reproduce the network's own
            # binarisation; otherwise the instance is dropped below.
            if orig_image_bin[instance_no, j] != (pixel >= C):
                numerically_unstable = True
                break
            if pixel < C:
                if C <= 255:
                    mod.append((- inputs[j], C - pixel))
                else:
                    mod.append((- inputs[j], None))  # impossible to change
            else:
                if C - 1 >= 0:
                    mod.append((inputs[j], (C - 1) - pixel))
                else:
                    mod.append((inputs[j], None))  # impossible to change
        else:
            # x ≤ ⌊255 (- βσ/γ + μ)⌋ = C
            C = int(math.floor(C_frac))
            if orig_image_bin[instance_no, j] != (pixel <= C):
                numerically_unstable = True
                break
            if pixel > C:
                if C >= 0:
                    mod.append((- inputs[j], C - pixel))
                else:
                    mod.append((- inputs[j], None))  # impossible to change
            else:
                if C + 1 <= 255:
                    mod.append((inputs[j], (C + 1) - pixel))
                else:
                    mod.append((inputs[j], None))  # impossible to change
    if numerically_unstable:
        print("numerically unstable")
        continue

    # debug: write a plain decision problem with every input literal fixed to
    # its value for the original image.
    if args.debug_sat:
        enc2 = copy.copy(enc)
        for lit, _ in mod:
            enc2.add_clause([lit])
        if args.format == "wcnf":
            fname = result_dir / f"bnn_{args.dataset}_{instance_no}_label{true_label}_{args.target}_{args.card}_debug.cnf"
        elif args.format == "wbo":
            fname = result_dir / f"bnn_{args.dataset}_{instance_no}_label{true_label}_{args.target}_debug.opb"
        else:
            raise RuntimeError(f"unknown ext: {args.format}")
        enc2.write_to_file(fname)

    for ratio in args.ratio:
        if ratio == 1.0:
            # Every pixel may be modified.
            mod2 = mod
            ratio_str = ""
        else:
            ratio_str = f"{int(ratio * 100)}p"
            # Keep only the `ratio` fraction of pixels with the largest
            # absolute saliency; the others become unchangeable (weight None).
            num_kept = int(len(saliency_map) * ratio)
            ranked = list(reversed(np.argsort(np.abs(saliency_map))))
            # Index by the pixel index i (not by instance_no) so that each
            # salient pixel maps to its own input variable.
            important_variables = {inputs[i] for i in ranked[:num_kept]}
            mod2 = [(lit, w if abs(lit) in important_variables else None) for lit, w in mod]

        for norm in args.norm:
            # File name parts; empty parts (e.g. ratio_str for 1.0) are dropped.
            xs: List[str] = [
                "bnn", args.dataset, str(instance_no), f"label{true_label}",
                args.target, "norm_" + str(norm), ratio_str
            ]
            if args.format == "wcnf":
                xs.append(args.card)
            fname = result_dir / ('_'.join([s for s in xs if len(s) > 0]) + "." + args.format)

            enc2 = copy.copy(enc)
            enc2.add_norm_cost(norm, mod2)
            enc2.write_to_file_opt(fname)