forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathverify.py
473 lines (404 loc) · 19.7 KB
/
verify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
import torch
import torch.jit
import torch.onnx
import onnx
import onnx.helper
import numpy as np
import difflib
import contextlib
import io
def colonize(msg, sep=": "):
    """
    Return 'msg' with 'sep' appended, or the empty string when 'msg' is
    falsy (None or ""), so that an absent prefix contributes nothing.
    """
    return msg + sep if msg else ""
class Errors(object):
    """
    An error-collecting object which supports error recovery.

    It is intended to be used like a context manager:

    >>> with Errors("Top-level error message") as errs:
    >>>     ...

    Errors recorded with the 'check*' methods (or 'addErr') are accumulated
    in 'self.errors'; when the with-block exits and any were recorded, a
    single AssertionError is raised whose message is the top-level message
    followed by every recorded error.  The 'require*' methods record the
    error and then short-circuit by raising an exception class private to
    this instance, which 'recover()' blocks know how to swallow.
    """

    def __init__(self, msg, rtol=1e-3, atol=1e-5):
        """
        Arguments:
            msg (str): top-level message used as the header of the final
                AssertionError.
            rtol (float): relative tolerance for the 'almost equal' tests.
            atol (float): absolute tolerance for the 'almost equal' tests.
        """
        self.msg = msg
        self.errors = []
        self.context = []
        self.rtol = rtol
        self.atol = atol

        # Allocated upon instance creation so that multiple Errors
        # can be used
        class ShortCircuit(Exception):
            pass
        self.exc_class = ShortCircuit

    def requireAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (equal within self.rtol
        precision); aborts execution if they are not.
        """
        self.almostEqualAndThen(x, y, msg, self.failWith)

    def checkAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (equal within self.rtol
        precision), but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call 'failIfErrs'
        at some later point in time.
        """
        self.almostEqualAndThen(x, y, msg, self.addErr)

    def almostEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing 'requireAlmostEqual' and 'checkAlmostEqual'.
        Upon failure, invokes continuation 'k' with the error message.

        At the moment, only tests on 'numpy.ndarray' are supported; any other
        argument types raise RuntimeError.
        """
        if not (isinstance(x, np.ndarray) and isinstance(y, np.ndarray)):
            raise RuntimeError("Unsupported almost equal test")

        failure = None
        try:
            np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol,
                                       equal_nan=True, verbose=True)
        except AssertionError as e:
            failure = str(e).lstrip()
        # The continuation is invoked outside the 'except' clause so that a
        # short-circuit raised by 'k' is not chained onto numpy's exception.
        if failure is not None:
            k("{}{}".format(colonize(msg), failure))

    def requireEqual(self, x, y, msg=None):
        """
        Test that x and y are equal; aborts execution if they are not.
        """
        self.equalAndThen(x, y, msg, self.failWith)

    def checkEqual(self, x, y, msg=None):
        """
        Test that x and y are equal, but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call 'failIfErrs'
        at some later point in time.
        """
        self.equalAndThen(x, y, msg, self.addErr)

    # Bit-for-bit accuracy test
    def equalAndThen(self, x, y, msg, k):
        """
        Helper for implementing 'requireEqual' and 'checkEqual'.  Upon failure,
        invokes continuation 'k' with the error message.

        Two onnx.TensorProto values are compared by name and then by their
        numpy contents; two numpy arrays are compared with
        'np.testing.assert_equal'; anything else is compared with '!='.
        """
        if isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto):
            self.equalAndThen(x.name, y.name, msg, k)
            # Use numpy for the comparison
            t1 = onnx.numpy_helper.to_array(x)
            t2 = onnx.numpy_helper.to_array(y)
            new_msg = "{}In embedded parameter '{}'".format(colonize(msg), x.name)
            self.equalAndThen(t1, t2, new_msg, k)
        elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            failure = None
            try:
                np.testing.assert_equal(x, y)
            except AssertionError as e:
                failure = str(e).lstrip()
            # See almostEqualAndThen: report through 'k' instead of letting
            # numpy's AssertionError escape unannotated.
            if failure is not None:
                k("{}{}".format(colonize(msg, ": "), failure))
        else:
            if x != y:
                # TODO: Better algorithm for lists
                sx = str(x)
                sy = str(y)
                if len(sx) > 40 or len(sy) > 40 or '\n' in sx or '\n' in sy:
                    # long form
                    rule = "=" * 50
                    k("\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}"
                      .format(colonize(msg, ":\n"), rule, sx, rule, rule, sy, rule))
                else:
                    k("{}{} != {}".format(colonize(msg), sx, sy))

    def requireMultiLineEqual(self, x, y, msg=None):
        """
        Test that long, multi-line strings x and y are equal;
        aborts execution if they are not.
        """
        self.multiLineEqualAndThen(x, y, msg, self.failWith)

    def multiLineEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing 'requireMultiLineEqual'.  Upon failure,
        invokes continuation 'k' with the error message, which contains a
        line-by-line 'difflib.ndiff' of the two strings.
        """
        if msg is None:
            msg = "Strings are not equal"
        if x != y:
            diff = difflib.ndiff(x.splitlines(True), y.splitlines(True))
            k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff)))

    def addErr(self, msg):
        """
        Add an error to the error context, but continue executing.

        Every message currently on the context stack (innermost first) is
        appended to 'msg' as a bullet before it is recorded.
        """
        # TODO: instead of immediately concatenating the context in the msg,
        # attach it as metadata and make a decision how to format it later.
        msg_w_ctx = msg
        for c in reversed(self.context):
            msg_w_ctx += "\n\n * " + "\n ".join(c.splitlines())
        self.errors.append(msg_w_ctx)

    def fail(self):
        """
        Immediately fail and short-circuit to the next recovery context.

        NB: It is an error to 'fail' without having added any errors to
        the error context.
        """
        raise self.exc_class()

    def failWith(self, msg):
        """
        Add an error to the error context, and then short-circuit.
        """
        self.addErr(msg)
        self.fail()

    def failIfErrs(self):
        """
        If there are any errors in the error context, short-circuit.

        This is used to prevent error cascades.
        """
        if self.errors:
            self.fail()

    def recover(self):
        """
        Returns a context manager which can be used to recover in case of
        an error.  Example usage:

        >>> with errs.recover():
        >>>     ...
        """
        parent_self = self

        class Recover(object):
            def __enter__(self):
                pass

            def __exit__(self, exc_type, exc_value, traceback):
                # Swallow only this Errors instance's short-circuit; any
                # other exception propagates (implicit None return).
                if exc_type == parent_self.exc_class:
                    return True

        return Recover()

    def addErrCtxt(self, msg):
        """
        Returns a context manager which encloses a fragment of code with
        an extra contextual message, e.g., where an error occurred, or a hint
        applicable to all errors in the area.  Example usage:

        >>> with errs.addErrCtxt("Some text"):
        >>>     ...
        """
        parent_self = self

        class AddContext(object):
            def __enter__(self):
                parent_self.context.append(msg)

            def __exit__(self, exc_type, exc_value, traceback):
                parent_self.context.pop()

        return AddContext()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Raise AssertionError summarizing all recorded errors, if any.  A
        short-circuit that reaches this point without any recorded error is
        reported as RuntimeError.
        """
        if self.errors:
            errors_msg = "\n\n".join(map(lambda x: "ERROR: " + x, self.errors))
            final_msg = "{}\n{}\n{}".format(self.msg, '-' * 70, errors_msg)
            raise AssertionError(final_msg)
        if exc_type == self.exc_class:
            raise RuntimeError("ShortCircuit was raised, but no errors were recorded")
@contextlib.contextmanager
def set_training(model, mode):
    """
    Context manager that runs the with-block with 'model' in training mode
    'mode'.  'model.train' is only called when the current 'model.training'
    differs from 'mode', and in that case the previous mode is restored on
    exit, even if the block raises.
    """
    previous_mode = model.training
    must_switch = previous_mode != mode
    if must_switch:
        model.train(mode)
    try:
        yield
    finally:
        if must_switch:
            model.train(previous_mode)
def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
           test_args=2, do_constant_folding=True, example_outputs=None, opset_version=None,
           keep_initializers_as_inputs=True, add_node_names=False):
    """
    Export a model into ONNX, import it into a specified ONNX backend, and then
    on a few random inputs verify that PyTorch and the backend produced the same
    results.  Requires onnx to be installed.

    This function may spuriously fail: some operators are implemented with
    different numerical precision in an ONNX backend, in which case an unstable
    network (e.g., Inception) may blow up these numerical instabilities.  This
    situation is less likely to happen if your model has been trained.  However,
    if this is not the case, you may have found a bug!  Please report it to the
    PyTorch developers.  You can also debug the issue yourself by removing
    suffixes of operators from your model until verification passes.

    For reproducibility, we recommend explicitly setting PyTorch's seed before
    invoking this function.

    Arguments:
        model (torch.nn.Module): the model to be exported and verified
        args (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        backend (onnx.backend module): ONNX backend to verify with
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode.  At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.
        rtol (float, default 1e-3): relative precision required
        atol (float, default 1e-7): absolute precision required; passed,
            together with rtol, to the Errors collectors used for comparison
        test_args (int or iterable of args, default 2):
            either an integer specifying the number
            of random arguments to generate, or an iterable producing arguments
            to test under.
        do_constant_folding (bool, default True): forwarded unchanged to
            ``torch.onnx._export``
        example_outputs (default None): forwarded unchanged to
            ``torch.onnx._export``
        opset_version (int, default None): the opset version of the model to
            export.  If not specified, the default value in symbolic_helper will
            be used in utils._export().
        keep_initializers_as_inputs (bool, default True): forwarded unchanged
            to ``torch.onnx._export``
        add_node_names (bool, default False): forwarded unchanged to
            ``torch.onnx._export``

    Raises:
        AssertionError: if a re-export with different inputs produces a
            different serialized model, or if the backend outputs are not
            close to the PyTorch outputs.
        ValueError: if ``args`` contains an object that is neither a Tensor,
            None, nor a list/tuple of those.
    """
    # Build a function that applies 'fn' to every leaf satisfying 'condition'
    # inside arbitrarily nested lists/tuples, preserving the container types
    # and passing None through untouched.
    def _nested_map(condition, fn, condition_msg=None):
        def _map(obj):
            if condition(obj):
                return fn(obj)
            elif obj is None:
                return None
            elif isinstance(obj, (list, tuple)):
                return type(obj)(_map(x) for x in obj)
            else:
                raise ValueError("Auto nesting doesn't know how to process "
                                 "an input object of type " + torch.typename(obj) +
                                 (". Accepted types: " + condition_msg +
                                  ", or lists/tuples of them"
                                  if condition_msg else ""))

        return _map

    # Build a generator function that yields, depth first, every leaf
    # satisfying 'condition' inside nested lists/tuples.  None is skipped;
    # other leaves are yielded as-is when 'allow_unknown' is set and rejected
    # with ValueError otherwise.
    def _iter_filter(condition, allow_unknown=False, condition_msg=None):
        def _iter(obj):
            if condition(obj):
                yield obj
            elif obj is None:
                return
            elif isinstance(obj, (list, tuple)):
                for o in obj:
                    for var in _iter(o):
                        yield var
            elif allow_unknown:
                yield obj
            else:
                raise ValueError("Auto nesting doesn't know how to process "
                                 "an input object of type " + torch.typename(obj) +
                                 (". Accepted types: " + condition_msg +
                                  ", or lists/tuples of them"
                                  if condition_msg else ""))

        return _iter

    def is_tensor(o):
        return isinstance(o, torch.Tensor)

    _iter_tensors = _iter_filter(is_tensor, condition_msg="Tensors")

    # Produce a fresh argument shaped like 'arg'; the input itself is cloned,
    # never modified.
    def randomize_arg(arg):
        new_data = arg.data.clone()
        # For now, don't try randomizing non-float tensors; these
        # are likely to be things like indices, where just randomly
        # spattering some longs is unlikely to work.  One way we could
        # make this work is to apply a random permutation or something.
        if arg.is_floating_point():
            new_data.uniform_()
        return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad)

    randomize_args = _nested_map(is_tensor, randomize_arg)

    # Flatten the (possibly nested) tensor arguments into the tuple of numpy
    # arrays handed to the backend.
    def backend_args(args):
        # TODO: onnx should accept iterables
        return tuple(v.data.cpu().numpy() for v in _iter_tensors(args))

    # Parse an ONNX model back out of the in-memory buffer the exporter
    # wrote into.
    def load_bytes(b):
        b.seek(0)
        x = onnx.load(b)
        # doc_string has stack traces - let's remove them to make comparison
        # sane
        onnx.helper.strip_doc_string(x)
        return x

    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args,)

    # Everything below (the reference export, every re-export and every
    # backend run) happens with the model temporarily in the requested mode.
    with set_training(model, training):
        # Reference export: 'proto' is what all later re-exports are compared
        # against, and 'prepared' is the backend representation that is run.
        proto_bytes = io.BytesIO()
        torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
                                       do_constant_folding=do_constant_folding,
                                       example_outputs=example_outputs,
                                       opset_version=opset_version,
                                       keep_initializers_as_inputs=keep_initializers_as_inputs,
                                       add_node_names=add_node_names)
        # For a ScriptModule the export result is not used as the reference
        # output; the module is called directly instead.
        if isinstance(model, torch.jit.ScriptModule):
            torch_out = model(*args)
        proto = load_bytes(proto_bytes)
        prepared = backend.prepare(proto)

        # Re-export the model with 'args', require that the serialized model
        # is identical to the reference export (diagnosing the difference
        # otherwise), then compare backend and PyTorch outputs on 'args'.
        def run(args):
            alt_proto_bytes = io.BytesIO()
            torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose,
                                           do_constant_folding=do_constant_folding,
                                           example_outputs=example_outputs,
                                           opset_version=opset_version,
                                           keep_initializers_as_inputs=keep_initializers_as_inputs,
                                           add_node_names=add_node_names)
            if isinstance(model, torch.jit.ScriptModule):
                torch_out = model(*args)
            alt_proto = load_bytes(alt_proto_bytes)
            if proto.SerializeToString() != alt_proto.SerializeToString():
                # OK, let's try to figure out what happened.
                msg = "When I exported your model with different inputs, the result was different."
                if not verbose:
                    msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))"
                with Errors(msg, rtol=rtol, atol=atol) as errs:
                    # First, check if we have the same number of parameters, and
                    # that they're the same order.  If they don't, something has *really* gone wrong.
                    initializer_order_hint = ("This is really strange! The second time I exported your model,\n"
                                              "it had a different set of parameters. Are you assigning Parameters\n"
                                              "in the forward() of your model definition?")
                    # Not wrapped in recover(): a mismatch here short-circuits
                    # straight out of the enclosing Errors block.
                    with errs.addErrCtxt(initializer_order_hint):
                        errs.requireEqual(list(map(lambda x: x.name, proto.graph.initializer)),
                                          list(map(lambda x: x.name, alt_proto.graph.initializer)),
                                          msg="Parameters list differs")

                    # Now check if the embedded parameters are actually the same
                    initializer_hint = ("A difference in embedded parameters usually means that\n"
                                        "your model is updating parameters/buffers even in inference\n"
                                        "mode. Look for a buggy nn.Module which isn't respecting train().\n")
                    with errs.recover(), errs.addErrCtxt(initializer_hint):
                        for x, y in zip(proto.graph.initializer, alt_proto.graph.initializer):
                            errs.checkEqual(x, y)

                    # Next, check if the model structure lines up.
                    structure_hint = ("A difference in model structure usually means that\n"
                                      "your model has dynamic control flow. These models are not\n"
                                      "currently supported by the exporter.")
                    with errs.recover(), errs.addErrCtxt(structure_hint):
                        # Delete initializers since we already tested them
                        stripped_proto = onnx.ModelProto()
                        stripped_proto.CopyFrom(proto)
                        del stripped_proto.graph.initializer[:]

                        stripped_alt_proto = onnx.ModelProto()
                        stripped_alt_proto.CopyFrom(alt_proto)
                        del stripped_alt_proto.graph.initializer[:]

                        # Compare the printable graph representations first
                        errs.requireMultiLineEqual(onnx.helper.printable_graph(stripped_proto.graph),
                                                   onnx.helper.printable_graph(stripped_alt_proto.graph))

                        # Compare the actual protobuf text formats now (not
                        # very user-friendly!)
                        errs.requireMultiLineEqual(str(stripped_proto), str(stripped_alt_proto))

                        # One last ditch effort, using built-in equality on
                        # protobufs
                        errs.requireEqual(stripped_proto, stripped_alt_proto)

                    # Stop here if any of the checks above recorded an error.
                    errs.failIfErrs()

                    # At this point, we should have figured out why the binary
                    # protobufs differed, and short-circuited out of this code
                    # with a helpful error message.  But what if we didn't?
                    # We better still try to give a good error message in this
                    # case.  We EXPECT these requires to fail.  If they don't,
                    # that is a bug in verify
                    errs.requireEqual(proto, alt_proto)
                    errs.requireEqual(proto_bytes.getvalue(), alt_proto_bytes.getvalue())
                    raise AssertionError()

            # TODO: test that the traced model also returns the same thing...
            run_helper(torch_out, args)

        # Factored out so we can avoid one run of the model
        def run_helper(torch_out, args):
            backend_out = prepared.run(backend_args(args))
            if isinstance(torch_out, torch.Tensor):
                torch_out = (torch_out,)
            torch_out, _ = torch._C._jit_flatten(torch_out)
            # NB: onnx backend NEVER returns bare numpy array
            msg = "ONNX backend returned different results from PyTorch"
            result_hint = ("If you are not using trained parameters, a difference in results\n"
                           "could mean that your network is numerically unstable.  Otherwise\n"
                           "it indicates a bug in PyTorch/ONNX; please file a bug report.")
            with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(result_hint):
                # NOTE(review): zip() stops at the shorter sequence, so a
                # difference in the *number* of outputs is not reported here.
                for i, (x, y) in enumerate(zip(torch_out, backend_out)):
                    errs.checkAlmostEqual(x.data.cpu().numpy(), y, "In output {}".format(i))

        # Check the original arguments first, reusing the output obtained
        # during the reference export.
        run_helper(torch_out, args)

        # Then check either 'test_args' randomized variants of 'args', or the
        # caller-supplied argument sets.
        if isinstance(test_args, int):
            for i in range(test_args):
                run(randomize_args(args))
        else:
            for test_arg in test_args:
                run(test_arg)