diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 74d064204ecca..533b2197bf30c 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -69,6 +69,7 @@ class TensorData: _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"]) def __init__(self, **kwargs): + self._attrs = list(kwargs.keys()) for k, v in kwargs.items(): if k not in TensorData._allowed: raise ValueError(f"Unexpected value {k!r} not in {TensorData._allowed}.") @@ -91,6 +92,12 @@ def avg_std(self): raise AttributeError(f"Attributes 'avg' and/or 'std' missing in {dir(self)}.") return (self.avg, self.std) + def to_dict(self): + # This is needed to serialize the data into JSON. + data = {k: getattr(self, k) for k in self._attrs} + data["CLS"] = self.__class__.__name__ + return data + class TensorsData: def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]]): @@ -125,12 +132,24 @@ def __setitem__(self, key, value): raise RuntimeError(f"Only an existing tensor can be modified, {key!r} is not.") self.data[key] = value + def keys(self): + return self.data.keys() + def values(self): return self.data.values() def items(self): return self.data.items() + def to_dict(self): + # This is needed to serialize the data into JSON. + data = { + "CLS": self.__class__.__name__, + "data": self.data, + "calibration_method": self.calibration_method, + } + return data + class CalibrationMethod(Enum): MinMax = 0 diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 0fdef4ef6f6d3..9228ad33130f2 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -671,21 +671,41 @@ def write_calibration_table(calibration_cache, dir="."): import json import flatbuffers + import numpy as np import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable + from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData logging.info(f"calibration cache: {calibration_cache}") + class MyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (TensorData, TensorsData)): + return obj.to_dict() + if isinstance(obj, np.ndarray): + return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"} + if isinstance(obj, CalibrationMethod): + return {"CLS": obj.__class__.__name__, "value": str(obj)} + return json.JSONEncoder.default(self, obj) + + json_data = json.dumps(calibration_cache, cls=MyEncoder) + with open(os.path.join(dir, "calibration.json"), "w") as file: - file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse + file.write(json_data) # use `json.loads` to do the reverse # Serialize data using FlatBuffers + zero = np.array(0) builder = flatbuffers.Builder(1024) key_value_list = [] for key in sorted(calibration_cache.keys()): values = calibration_cache[key] - value = str(max(abs(values[0]), abs(values[1]))) + d_values = values.to_dict() + floats = [ + float(d_values.get("highest", zero).item()), + float(d_values.get("lowest", zero).item()), + ] + value = str(max(floats)) flat_key = builder.CreateString(key) flat_value = builder.CreateString(value) @@ -724,9 +744,14 @@ def write_calibration_table(calibration_cache, dir="."): # write plain text with open(os.path.join(dir, "calibration.cache"), "w") as file: for key in sorted(calibration_cache.keys()): - value = calibration_cache[key] - s = key + " " + str(max(abs(value[0]), abs(value[1]))) - file.write(s) + values = calibration_cache[key] + d_values = values.to_dict() + floats = [ + float(d_values.get("highest", zero).item()), + float(d_values.get("lowest", zero).item()), + ] + value = key + " " + str(max(floats)) + file.write(value) file.write("\n") diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index d3718de1aa56a..b99c11abf6d2c 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -22,8 +22,8 @@ create_clip_node, ) -from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static -from onnxruntime.quantization.calibrate import TensorData +from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static, write_calibration_table +from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData class TestQDQFormat(unittest.TestCase): @@ -1720,6 +1720,11 @@ def test_int4_qdq_per_channel_conv(self): size_ratio = weight_quant_init.ByteSize() / unpacked_size self.assertLess(size_ratio, 0.55) + def test_json_serialization(self): + td = TensorData(lowest=np.array([0.1], dtype=np.float32), highest=np.array([1.1], dtype=np.float32)) + new_calibrate_tensors_range = TensorsData(CalibrationMethod.MinMax, {"td": td}) + write_calibration_table(new_calibrate_tensors_range) + if __name__ == "__main__": unittest.main()