fix remaining unit tests
xadupre committed Oct 26, 2023
1 parent 170e5c9 commit 47b41b6
Showing 5 changed files with 46 additions and 35 deletions.
58 changes: 37 additions & 21 deletions onnxruntime/python/tools/quantization/calibrate.py
@@ -290,8 +290,12 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
             )

             self.model.graph.node.extend([reduce_node, reshape_node])
-            # This function assumes the output type is the same as the first input type.
-            onnx_type = self.model.graph.input[0].type.tensor_type.elem_type
+            io_dict = {o.name: o for o in self.model.graph.output}
+            io_dict.update({i.name: i for i in self.model.graph.input})
+            if tensor_name in io_dict:
+                onnx_type = io_dict[tensor_name].type.tensor_type.elem_type
+            else:
+                raise ValueError(f"Unable to guess tensor type for tensor {tensor_name!r}")
             self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [1]))

         for tensor in tensors:
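
Note: the replacement logic looks up a tensor's ONNX element type by name in the graph's outputs and inputs instead of assuming it matches the first input. A minimal standalone sketch of the same lookup, using a hypothetical one-node model:

    from onnx import TensorProto, helper

    # Tiny graph with a float input "X" routed to output "Y" (hypothetical names).
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4])
    node = helper.make_node("Identity", ["X"], ["Y"])
    model = helper.make_model(helper.make_graph([node], "demo", [X], [Y]))

    def guess_tensor_type(model, tensor_name):
        # Outputs first, then inputs, mirroring the lookup order in the diff.
        io_dict = {o.name: o for o in model.graph.output}
        io_dict.update({i.name: i for i in model.graph.input})
        if tensor_name in io_dict:
            return io_dict[tensor_name].type.tensor_type.elem_type
        raise ValueError(f"Unable to guess tensor type for tensor {tensor_name!r}")

    assert guess_tensor_type(model, "Y") == TensorProto.FLOAT
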
@@ -681,41 +685,53 @@ def collect_absolute_value(self, name_to_arr):
         Collect histogram on absolute value
         """
         for tensor, data_arr in name_to_arr.items():
-            if not isinstance(data_arr, np.ndarray):
+            if isinstance(data_arr, list):
+                for arr in data_arr:
+                    if not isinstance(arr, np.ndarray):
+                        raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}")
+                dtypes = set(a.dtype for a in data_arr)
+                if len(dtypes) != 1:
+                    raise ValueError(
+                        f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
+                    )
+                data_arr_np = np.asarray(data_arr)
+            elif not isinstance(data_arr, np.ndarray):
                 raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
-            data_arr = data_arr.flatten()  # noqa: PLW2901
-            if data_arr.size > 0:
-                min_value = np.min(data_arr)
-                max_value = np.max(data_arr)
+            else:
+                data_arr_np = data_arr
+            data_arr_np = data_arr_np.flatten()
+            if data_arr_np.size > 0:
+                min_value = np.min(data_arr_np)
+                max_value = np.max(data_arr_np)
             else:
                 min_value = 0
                 max_value = 0

-            data_arr = np.absolute(data_arr)  # only consider absolute value  # noqa: PLW2901
+            data_arr_np = np.absolute(data_arr_np)  # only consider absolute value

             if tensor not in self.histogram_dict:
                 # first time it uses num_bins to compute histogram.
-                hist, hist_edges = np.histogram(data_arr, bins=self.num_bins)
-                hist_edges = hist_edges.astype(data_arr.dtype)
-                assert data_arr.dtype != np.float64
+                hist, hist_edges = np.histogram(data_arr_np, bins=self.num_bins)
+                hist_edges = hist_edges.astype(data_arr_np.dtype)
+                assert data_arr_np.dtype != np.float64
                 self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
             else:
                 old_histogram = self.histogram_dict[tensor]
                 old_min = old_histogram[2]
                 old_max = old_histogram[3]
                 old_hist = old_histogram[0]
                 old_hist_edges = old_histogram[1]
-                temp_amax = np.max(data_arr)
+                temp_amax = np.max(data_arr_np)
                 if temp_amax > old_hist_edges[-1]:
                     # increase the number of bins
                     width = old_hist_edges[1] - old_hist_edges[0]
                     # NOTE: np.arange may create an extra bin after the one containing temp_amax
                     new_bin_edges = np.arange(old_hist_edges[-1] + width, temp_amax + width, width)
                     old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
-                hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges)
-                hist_edges = hist_edges.astype(data_arr.dtype)
+                hist, hist_edges = np.histogram(data_arr_np, bins=old_hist_edges)
+                hist_edges = hist_edges.astype(data_arr_np.dtype)
                 hist[: len(old_hist)] += old_hist
-                assert data_arr.dtype != np.float64
+                assert data_arr_np.dtype != np.float64
                 self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))

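Note: the merge path above extends the histogram's bin edges by whole bins whenever a new batch exceeds the previous range, then folds the old counts into the recomputed histogram. A self-contained numpy sketch of that merge, on made-up data:

    import numpy as np

    # First batch: a fixed number of bins over the observed range.
    batch1 = np.abs(np.array([0.1, 0.4, 0.9], dtype=np.float32))
    hist, hist_edges = np.histogram(batch1, bins=4)

    # Second batch exceeds the old right edge, so grow the edges by whole bins.
    batch2 = np.abs(np.array([1.7, 0.2], dtype=np.float32))
    amax = np.max(batch2)
    if amax > hist_edges[-1]:
        width = hist_edges[1] - hist_edges[0]
        # np.arange may create an extra bin past the one containing amax.
        new_edges = np.arange(hist_edges[-1] + width, amax + width, width)
        hist_edges = np.hstack((hist_edges, new_edges))

    hist2, hist_edges = np.histogram(batch2, bins=hist_edges)
    hist2[: len(hist)] += hist  # accumulate the earlier counts
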
@@ -730,8 +746,8 @@ def collect_value(self, name_to_arr):
                 min_value = np.min(data_arr)
                 max_value = np.max(data_arr)
             else:
-                min_value = 0
-                max_value = 0
+                min_value = np.array(0, dtype=data_arr.dtype)
+                max_value = np.array(0, dtype=data_arr.dtype)

             threshold = max(abs(min_value), abs(max_value))

@@ -818,16 +834,16 @@ def compute_percentile(self):
                 idx_right = np.searchsorted(cdf, percentile / 100.0)

                 thresholds_dict[tensor] = (
-                    -float(hist_edges[idx_right]),
-                    float(hist_edges[idx_right]),
+                    -np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
+                    np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
                 )
             else:
                 percent_to_cut_one_side = (100.0 - percentile) / 200.0
                 idx_right = np.searchsorted(cdf, 1.0 - percent_to_cut_one_side)
                 idx_left = np.searchsorted(cdf, percent_to_cut_one_side)
                 thresholds_dict[tensor] = (
-                    float(hist_edges[idx_left]),
-                    float(hist_edges[idx_right]),
+                    np.array(hist_edges[idx_left], dtype=hist_edges.dtype),
+                    np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
                 )
             min_value = histogram[2]
             max_value = histogram[3]
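Note: the thresholds are now reported in the histogram's own dtype rather than as Python floats, so downstream scale computation stays in the calibrated precision. The selection itself is a CDF lookup; a small sketch of the technique on synthetic data:

    import numpy as np

    rng = np.random.default_rng(0)
    data = np.abs(rng.standard_normal(10_000).astype(np.float32))

    hist, hist_edges = np.histogram(data, bins=128)
    cdf = np.cumsum(hist / hist.sum())

    percentile = 99.9  # symmetric case from the diff above
    idx_right = np.searchsorted(cdf, percentile / 100.0)
    threshold = np.array(hist_edges[idx_right], dtype=hist_edges.dtype)
    bounds = (-threshold, threshold)  # covers roughly 99.9% of the values
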
7 changes: 1 addition & 6 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -724,9 +724,6 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
                 [output_name, scale_name, zp_name],
                 ql_node_name,
             )
-            self.quantized_value_map[input_name] = QuantizedValue(
-                input_name, output_name, scale_name, zp_name, qType
-            )
         else:
             (
                 scale_name,
@@ -740,10 +737,8 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
                 [output_name],
                 ql_node_name,
             )
-            self.quantized_value_map[input_name] = QuantizedValue(
-                input_name, output_name, scale_name, zp_name, qType
-            )

+        self.quantized_value_map[input_name] = QuantizedValue(input_name, output_name, scale_name, zp_name, qType)
         return [*nodes, qlinear_node]

     def set_quant_scale_zp(self, tensor_name, value):
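
Note: both branches built their quantize node and then recorded an identical QuantizedValue, so the bookkeeping is hoisted below the if/else and now runs exactly once. A sketch of the consolidated pattern (the names here are illustrative, not the quantizer's API):

    from onnx import helper

    quantized_value_map = {}

    def add_quantize_node(input_name, dynamic):
        # The branches differ only in the node they emit, as in the diff:
        if dynamic:
            node = helper.make_node(
                "DynamicQuantizeLinear",
                [input_name],
                [f"{input_name}_quantized", f"{input_name}_scale", f"{input_name}_zero_point"],
            )
        else:
            node = helper.make_node(
                "QuantizeLinear",
                [input_name, f"{input_name}_scale", f"{input_name}_zero_point"],
                [f"{input_name}_quantized"],
            )
        # Shared bookkeeping, hoisted out of the branches.
        quantized_value_map[input_name] = node.output[0]
        return node

    add_quantize_node("input0", dynamic=False)
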
4 changes: 3 additions & 1 deletion onnxruntime/test/python/quantization/op_test_utils.py
@@ -511,7 +511,9 @@ def check_qtype_by_node_type(testcase, model_to_check, check_list):
         for check_item in input_output_check_list:
             tensor_name = node.input[check_item[1]] if check_item[0] == "i" else node.output[check_item[1]]
             if tensor_name not in value_infos and tensor_name not in initializers:
-                raise AssertionError(f"Unable to find tensor_name={tensor_name!r}\n{model}")
+                raise AssertionError(
+                    f"Unable to find tensor_name={tensor_name!r} in {list(sorted(value_infos))}\n{model}"
+                )
             if tensor_name in value_infos:
                 vi = value_infos[tensor_name]
                 testcase.assertTrue(vi.type.HasField("tensor_type"))
10 changes: 5 additions & 5 deletions onnxruntime/test/python/quantization/test_op_pad.py
@@ -496,6 +496,8 @@ def test_pad_with_empty_string_input_name(self):
         )

         model_fp32 = TestOpQuatizerPad.construct_model_add_pad_add(name=name, shape=shape, final_name="output")
+        op_types = [n.op_type for n in model_fp32.graph.node]
+        self.assertEqual(["Add", "Pad", "Add"], op_types)

         onnx.save(model_fp32, model_fp32_path)

@@ -506,13 +508,11 @@
         )

         model_i8 = onnx.load(model_i8_path)
-        print(model_i8)

         # Assert quantization really happens.
-        self.assertEqual(model_i8.graph.node[0].op_type, "QuantizeLinear")
-        self.assertEqual(model_i8.graph.node[1].op_type, "QLinearAdd")
-        self.assertEqual(model_i8.graph.node[2].op_type, "Pad")
-        self.assertEqual(model_i8.graph.node[3].op_type, "QLinearAdd")
-        self.assertEqual(model_i8.graph.node[4].op_type, "DequantizeLinear")
+        op_types = [n.op_type for n in model_i8.graph.node]
+        self.assertEqual(["QuantizeLinear", "QLinearAdd", "Pad", "QLinearAdd", "DequantizeLinear"], op_types)

         for node in model_i8.graph.node:
             # Examine no empty string flows to quantization process.
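Note: asserting the full op-type sequence in one assertEqual reports the whole list on mismatch and, unlike indexing node[4] directly, cannot raise IndexError when the quantized graph has fewer nodes than expected. A minimal sketch of the pattern:

    import unittest

    class _Demo(unittest.TestCase):
        def runTest(self):
            op_types = ["QuantizeLinear", "QLinearAdd", "Pad"]
            # A single comparison shows the entire sequence in the failure message.
            self.assertEqual(["QuantizeLinear", "QLinearAdd", "Pad"], op_types)

    _Demo().runTest()
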
2 changes: 0 additions & 2 deletions onnxruntime/test/python/quantization/test_op_where.py
@@ -145,11 +145,9 @@ def quantize_where_test(self, activation_type, weight_type, extra_options={}):

     def test_quantize_where_u8u8(self):
         self.quantize_where_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True})
-        print(__name__)

     def test_quantize_where_u8u8_no_force_quantize_no_input_check(self):
         self.quantize_where_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False})
-        print(__name__)


 if __name__ == "__main__":
