forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_pytorch_onnx_caffe2_quantized.py
208 lines (167 loc) · 8.05 KB
/
test_pytorch_onnx_caffe2_quantized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from __future__ import print_function
import numpy as np
import unittest
import torch.onnx
import io
import onnx
import caffe2.python.onnx.backend as c2
class TestQuantizedOps(unittest.TestCase):
    """Round-trip quantized PyTorch models through ONNX and Caffe2.

    Every test quantizes a small float model on the QNNPACK engine (with no
    calibration pass between ``prepare`` and ``convert``), exports it using
    ``OperatorExportTypes.ONNX_ATEN_FALLBACK`` and asserts that the Caffe2
    ONNX backend reproduces the PyTorch output to 3 decimal places.
    """

    def _quantize(self, model):
        """Return a quantized copy of ``model`` for the QNNPACK engine.

        Sets ``torch.backends.quantized.engine`` globally and assigns
        ``default_qconfig`` to ``model.qconfig``. ``prepare`` and ``convert``
        run with ``inplace=False``. No calibration data is passed through the
        observers before conversion.
        """
        torch.backends.quantized.engine = "qnnpack"
        model.qconfig = torch.quantization.default_qconfig
        prepared = torch.quantization.prepare(model, inplace=False)
        return torch.quantization.convert(prepared, inplace=False)

    def generic_test(self, model, sample_inputs, input_names=None):
        """Quantize ``model``, export it to ONNX and compare against Caffe2.

        Args:
            model: float ``torch.nn.Module`` that wraps its inputs in
                QuantStubs and its result in a DeQuantStub.
            sample_inputs: numpy arrays passed positionally to the model.
            input_names: ONNX graph input names, one per entry of
                ``sample_inputs``. Defaults to ``input_0``, ``input_1``, ...
                so that export and the Caffe2 feed dict always agree.

        Raises:
            ValueError: if ``input_names`` and ``sample_inputs`` differ in
                length.
        """
        sample_inputs = tuple(sample_inputs)
        if input_names is None:
            # zip(None, ...) would raise TypeError; generate explicit names so
            # the same names are used for the export and the Caffe2 feed.
            input_names = ["input_%d" % i for i in range(len(sample_inputs))]
        if len(input_names) != len(sample_inputs):
            # zip() would otherwise silently drop the unmatched inputs.
            raise ValueError("expected %d input names, got %d"
                             % (len(sample_inputs), len(input_names)))

        q_model = self._quantize(model)
        pt_inputs = tuple(torch.from_numpy(x) for x in sample_inputs)
        pytorch_res = q_model(*pt_inputs)

        f = io.BytesIO()
        torch.onnx.export(q_model, pt_inputs, f, input_names=input_names,
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
        f.seek(0)
        onnx_model = onnx.load(f)

        caffe_res = c2.run_model(onnx_model, dict(zip(input_names, sample_inputs)))[0]
        np.testing.assert_almost_equal(pytorch_res.numpy(), caffe_res, decimal=3)

    def generic_unary_test(self, op):
        """Run ``generic_test`` on ``op`` applied to a quantized (1, 2) input."""
        class QModule(torch.nn.Module):
            def __init__(self, op):
                super(QModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.op = op
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                res = self.op(self.quant1(x))
                return self.dequant(res)

        x = np.random.random((1, 2)).astype("float32")
        self.generic_test(QModule(op), (x,), input_names=["x"])

    def test_quantized_add(self):
        """quantized::add of two separately quantized inputs."""
        class QAddModule(torch.nn.Module):
            def __init__(self):
                super(QAddModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.quant2 = torch.quantization.QuantStub()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x, y):
                res = torch.ops.quantized.add(self.quant1(x), self.quant2(y), 1.0, 0)
                return self.dequant(res)

        x = np.random.random(2).astype("float32")
        y = np.random.random(2).astype("float32")
        self.generic_test(QAddModule(), (x, y), input_names=["x", "y"])

    def test_quantized_relu(self):
        """ReLU on a quantized tensor."""
        self.generic_unary_test(torch.nn.ReLU())

    def export_to_onnx(self, model, input, input_names):
        """Trace ``model``, reload the trace, and export it to an ONNX proto.

        The traced module is serialized with ``torch.jit.save`` and read back
        with ``torch.jit.load`` before export. NOTE(review): presumably this
        exercises export of a deserialized ScriptModule; confirm intent.
        ``example_outputs`` comes from running the original ``model`` on
        ``input``.

        Returns:
            The exported model as loaded by ``onnx.load``.
        """
        outputs = model(input)
        traced = torch.jit.trace(model, input)

        buf = io.BytesIO()
        torch.jit.save(traced, buf)
        buf.seek(0)
        reloaded = torch.jit.load(buf)

        f = io.BytesIO()
        torch.onnx.export(reloaded, input, f, input_names=input_names, example_outputs=outputs,
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
        f.seek(0)
        return onnx.load(f)

    def test_qlinear_model(self):
        """Quantized Linear(5, 10) exported through a TorchScript round trip."""
        class LinearModel(torch.nn.Module):
            def __init__(self):
                super(LinearModel, self).__init__()
                self.qconfig = torch.quantization.default_qconfig
                self.fc1 = torch.quantization.QuantWrapper(torch.nn.Linear(5, 10).to(dtype=torch.float))

            def forward(self, x):
                x = self.fc1(x)
                return x

        model = self._quantize(LinearModel())
        x_numpy = np.random.rand(1, 2, 5).astype(np.float32)
        x = torch.from_numpy(x_numpy).to(dtype=torch.float)
        outputs = model(x)
        input_names = ["x"]
        onnx_model = self.export_to_onnx(model, x, input_names)
        # NOTE(review): Caffe2 is fed the (2, 5) slice x_numpy[0], not the
        # (1, 2, 5) array the model was traced with, and the PyTorch output is
        # squeezed to match. Confirm whether the Caffe2 quantized FC path
        # requires the 2-D input.
        caffe_res = c2.run_model(onnx_model, {input_names[0]: x_numpy[0]})[0]
        np.testing.assert_almost_equal(np.squeeze(outputs.numpy()), caffe_res, decimal=3)

    def test_qconv_model(self):
        """Quantized Conv2d(3, 5, 2) exported through a TorchScript round trip."""
        class ConvModel(torch.nn.Module):
            def __init__(self):
                super(ConvModel, self).__init__()
                self.qconfig = torch.quantization.default_qconfig
                self.fc1 = torch.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True).to(dtype=torch.float))

            def forward(self, x):
                x = self.fc1(x)
                return x

        model = self._quantize(ConvModel())
        x_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
        x = torch.from_numpy(x_numpy).to(dtype=torch.float)
        outputs = model(x)
        input_names = ["x"]
        onnx_model = self.export_to_onnx(model, x, input_names)
        # The full (1, 3, 6, 6) input is fed, and the outputs are compared
        # as-is; no layout permutation is applied on either side.
        caffe_res = c2.run_model(onnx_model, {input_names[0]: x_numpy})[0]
        np.testing.assert_almost_equal(outputs.numpy(), caffe_res, decimal=3)

    def test_upsample(self):
        """Nearest-neighbour interpolate of a quantized tensor to 6x8."""
        class QUpsampleModule(torch.nn.Module):
            def __init__(self):
                super(QUpsampleModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                res = torch.nn.quantized.functional.interpolate(self.quant1(x), size=[6, 8], mode='nearest')
                return self.dequant(res)

        x = np.random.rand(1, 2, 3, 4).astype("float32")
        self.generic_test(QUpsampleModule(), (x,), input_names=["x"])

    def test_avg_pool2d(self):
        """2x2, stride-1 average pooling of a quantized tensor."""
        class QAvgPool2dModule(torch.nn.Module):
            def __init__(self):
                super(QAvgPool2dModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                res = torch.nn.functional.avg_pool2d(self.quant1(x), kernel_size=2, stride=1, padding=0)
                return self.dequant(res)

        x = np.random.rand(1, 2, 8, 8).astype("float32")
        self.generic_test(QAvgPool2dModule(), (x,), input_names=["x"])

    def test_reshape(self):
        """Reshape of a quantized (1, 2, 3, 4) tensor to (1, 2, 1, 12)."""
        class QReshapeModule(torch.nn.Module):
            def __init__(self):
                super(QReshapeModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                res = self.quant1(x).reshape((1, 2, 1, 12))
                return self.dequant(res)

        x = np.random.rand(1, 2, 3, 4).astype("float32")
        self.generic_test(QReshapeModule(), (x,), input_names=["x"])

    def test_slice(self):
        """Slicing dim 1 of a quantized tensor."""
        class QSliceModule(torch.nn.Module):
            def __init__(self):
                super(QSliceModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x):
                qx = self.quant1(x)
                res = qx[:, 1:2]
                return self.dequant(res)

        x = np.random.rand(1, 2, 3, 4).astype("float32")
        self.generic_test(QSliceModule(), (x,), input_names=["x"])

    def test_cat(self):
        """quantized::cat along dim 1; both inputs share one QuantStub."""
        class QConcatModule(torch.nn.Module):
            def __init__(self):
                super(QConcatModule, self).__init__()
                self.quant1 = torch.quantization.QuantStub()
                self.dequant = torch.quantization.DeQuantStub()

            def forward(self, x, y):
                res = torch.ops.quantized.cat([self.quant1(x), self.quant1(y)], dim=1, scale=1.0, zero_point=0)
                return self.dequant(res)

        x = np.random.rand(1, 2, 3, 4).astype("float32")
        y = np.random.rand(1, 4, 3, 4).astype("float32")
        self.generic_test(QConcatModule(), (x, y,), input_names=["x", "y"])
if __name__ == '__main__':
    # Run this module's unittest test cases when executed as a script.
    unittest.main()