Skip to content

Commit

Permalink
onnx 模型导出与算法更新 (#231)
Browse files Browse the repository at this point in the history
* 添加样例文件

* onnx模型导出与算法更新

* 为dump_torch_to_onnx函数添加了默认参数
* EXPORT_OVERLAPPED_CONFIG 现在是过时参数,你将使用TQC上的QuantizationVisiblity属性来进行导出控制。该属性有三个可选项:强制导出、TQC激活时导出、不导出。
* 修改了 exporter 逻辑以适配新的QuantizationVisiblity属性
* 修改了onnx qdq的导出逻辑,现在将尽可能消除对称量化中的激活函数。
* 修改了 graphwise analyser 的逻辑,现在允许分析多输出算子的误差
* 修改了 layerwise equalization 的逻辑,现在允许 include act,支持 conv1d, conv2d, conv3d, convtranspose1d, convtranspose2d, convtranspose3d, gemm, matmul
* 修复了 passive parameter pass 中的 pad 量化错误
* 修复了 quant alignment pass 中 pooling 算子的对齐错误
* 修复了 核心量化函数在启动 cuda kernel 的情况下无法处理 cpu tensor 的问题
* 修改 openvino 量化策略,负数部分现在可以取到-128(曾经是-127)
* 给 dsp quantizer 添加了一个新的量化类型

* 添加测试样例

* 修复ci错误
  • Loading branch information
ZhangZhiPku authored Aug 31, 2022
1 parent ec0429b commit 7883312
Show file tree
Hide file tree
Showing 20 changed files with 696 additions and 490 deletions.
4 changes: 3 additions & 1 deletion ppq/IR/base/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def copy(self, copy_value: bool = False):
'however its value is not an instance of torch.Tensor, '
'ppq will automaticall convert it to torch.Tensor now.')
self.value = convert_any_to_torch_tensor(self.value)
return Variable(name=self.name, value=self.value.clone(), is_parameter=self.is_parameter)
if isinstance(self.value, torch.Tensor):
value = self.value.clone()
return Variable(name=self.name, value=value, is_parameter=self.is_parameter)


class Operation(OperationBase, Serializable):
Expand Down
2 changes: 1 addition & 1 deletion ppq/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def dump_torch_to_onnx(
model: torch.nn.Module,
onnx_export_file: str,
input_shape: List[int],
input_dtype: torch.dtype,
input_dtype: torch.dtype = torch.float,
inputs: List[Any] = None,
device: str = 'cuda'):
"""
Expand Down
9 changes: 4 additions & 5 deletions ppq/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# PPQ System configuration
# You can modify following codes for your own purpose.


# Observer 中,最小 scale 限制,所有小于该值的 scale 将被该值覆盖
OBSERVER_MIN_SCALE = 1e-8
# Observer 中,最小 scale 的手动覆盖属性
Expand Down Expand Up @@ -64,9 +63,6 @@
DEFAULT_OPSET_VERSION = 11
STRICT_OPSET_CHECKING = False

# 导出 qdq 节点时是否需要导出状态已经是 overlap 的节点
EXPORT_OVERLAPPED_CONFIG = False

# LSTM 算子的权重缓存属性
LSTM_FLATTEN_WEIGHT_ATTRIB = 'LSTM_FLATTEN_WEIGHT_ATTRIB'
# GRU 算子的权重缓存属性
Expand All @@ -90,4 +86,7 @@
CHECKPOINT_TOLERANCE = 1

# 要做 Bias Correction 的算子种类
BIAS_CORRECTION_INTERST_TYPE = {'Conv', 'Gemm', 'ConvTranspose'}
BIAS_CORRECTION_INTERST_TYPE = {'Conv', 'Gemm', 'ConvTranspose'}

# 导出 qdq 节点时是否需要导出状态已经是 overlap 的节点
EXPORT_OVERLAPPED_CONFIG = False
55 changes: 35 additions & 20 deletions ppq/core/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
"""

import time # for hash generation
from abc import abstractmethod
from enum import Enum
from typing import Any, Iterable, List

import torch

from .common import EXPORT_OVERLAPPED_CONFIG
from .storage import Serializable


class QuantizationVisiblity(Enum):
    """Export visibility level of a TensorQuantizationConfig (TQC).

    The level is consulted by ``TensorQuantizationConfig.can_export``:

    * ``FORCE_EXPORT``: the TQC is exported regardless of its quantization
      state (scale and offset must still be tensors).
    * ``EXPOET_WHEN_ACTIVE``: the TQC is exported only when its state is
      activated or otherwise accepted by the exporter's state check.
    * ``INTERNAL``: the TQC is never exported.
    """
    FORCE_EXPORT = 1
    # The historical (misspelled) name stays the canonical member so that
    # existing references and ``.name`` values keep working.
    EXPOET_WHEN_ACTIVE = 2
    INTERNAL = 3

    # Correctly spelled alias. Enum members sharing a value are aliases, so
    # this is the very same object as EXPOET_WHEN_ACTIVE and does not show up
    # as an extra member when iterating over the enum.
    EXPORT_WHEN_ACTIVE = 2

class NetworkFramework(Enum):
PPL = 1
ONNX = 2
Expand Down Expand Up @@ -365,7 +370,7 @@ def __init__(
offset: Any = None,
observer_algorithm: str = None,
detail: Any = None,
require_export: bool = None,
visiblity: QuantizationVisiblity = QuantizationVisiblity.EXPOET_WHEN_ACTIVE,
state: QuantizationStates = QuantizationStates.INITIAL
):
"""Create a PPQ Tensor Quantization Configuration Instance.
Expand Down Expand Up @@ -395,7 +400,13 @@ def __init__(
detail (Any, optional): Only used by PPQ internal logic, detail is used to store some internal data,
you are not supposed to use it.
require_export (bool, optional): If require_export == True, PPQ exporter will export this TQC ignoring state checks.
visiblity (QuantizationVisiblity): visiblity is the attribute that controls export logic.
Currently, there are 3 visibility levels in PPQ:
if visiblity == FORCE_EXPORT, the PPQ exporter will export this TQC
ignoring the state check (even if the current TQC has been overlapped).
if visiblity == EXPOET_WHEN_ACTIVE, the PPQ exporter will export this TQC only when it has been activated.
if visiblity == INTERNAL, this TQC will not be exported.
state (QuantizationStates, optional):
Defaults to QuantizationStates.INITIAL, see QuantizationStates for more detail.
Expand All @@ -416,17 +427,25 @@ def __init__(
self.detail = {} if detail is None else detail
self._father_config = self # union-find
self._hash = self.__create_hash()
self._require_export = require_export
self._visiblity = visiblity
super().__init__()

@ abstractmethod
def export(self) -> str:
raise Exception('Implement this first')
def can_export(self) -> bool:
    """Tell whether this quantization config should be written by an exporter.

    Returns:
        False when the visibility is ``INTERNAL``. Otherwise True only if
        both ``scale`` and ``offset`` are ``torch.Tensor`` instances and
        either the state passes the export state check or the visibility is
        ``FORCE_EXPORT``.
    """
    # Internal configs are never exported, whatever their state.
    if self.visiblity == QuantizationVisiblity.INTERNAL:
        return False

    # Both quantization parameters must already be materialised as tensors.
    has_tensor_params = (
        isinstance(self.scale, torch.Tensor)
        and isinstance(self.offset, torch.Tensor))

    # States accepted in addition to the "activated" ones; overlapped configs
    # are accepted only when the global switch asks for them.
    exportable_states = {QuantizationStates.BAKED, QuantizationStates.PASSIVE_BAKED}
    if EXPORT_OVERLAPPED_CONFIG:
        exportable_states.add(QuantizationStates.OVERLAPPED)

    state_ok = (
        QuantizationStates.is_activated(self.state)
        or self.state in exportable_states)

    # FORCE_EXPORT bypasses the state check but not the tensor check.
    if not (state_ok or self.visiblity == QuantizationVisiblity.FORCE_EXPORT):
        return False
    return has_tensor_params

def __eq__(self, o: object) -> bool:
    """Compare two quantization configs by their identity hash.

    Args:
        o: Object to compare against; must be a TensorQuantizationConfig.

    Returns:
        True when both configs carry the same ``_hash`` value.

    Raises:
        TypeError: If ``o`` is not a TensorQuantizationConfig.
    """
    if not isinstance(o, TensorQuantizationConfig):
        # Single raise; adjacent string literals are joined implicitly, so
        # no backslash continuation is needed.
        raise TypeError('Can only compare TensorQuantizationConfig object '
                        'with another TensorQuantizationConfig object.')
    return self._hash == o._hash

def __str__(self) -> str:
Expand Down Expand Up @@ -509,17 +528,13 @@ def is_revisable(self):
})

@ property
def exportable(self) -> bool:
value_check = isinstance(self.scale, torch.Tensor)
if self._require_export is None:
state_check = QuantizationStates.can_export(self.state)
return (value_check and state_check)
else: return (self._require_export and value_check)

@ exportable.setter
def exportable(self, export_override: bool):
self._require_export = export_override

# Getter: returns the export-visibility level held in self._visiblity.
# The stored value is the QuantizationVisiblity passed to __init__, and
# can_export() compares it against QuantizationVisiblity members.
def visiblity(self) -> QuantizationVisiblity:
return self._visiblity

# Setter: replaces the stored export-visibility level without validation.
@ visiblity.setter
def visiblity(self, visiblity: QuantizationVisiblity):
self._visiblity = visiblity

@ property
def scale(self) -> torch.Tensor:
if self.dominated_by == self: return self._scale
Expand Down
7 changes: 1 addition & 6 deletions ppq/parser/caffe_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
for operation in graph.operations.values():
if not isinstance(operation, QuantableOperation): continue
for config, var in operation.config_with_variable:
if not QuantizationStates.can_export(config.state):
raise PermissionError(
'Can not export quant config cause not all quantization configurations '
'have been correctly initialized(or some of them has been deactivated). '
f'Operation {operation.name} has an invalid quantization config({config.state}) '
f'at variable {var.name}.')
if not config.can_export(): continue

# PATCH 2021.11.25
# REMOVE BIAS FROM CONFIGURATION
Expand Down
3 changes: 3 additions & 0 deletions ppq/parser/ncnn_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
if op.is_computing_op and isinstance(op, QuantableOperation):
fd.write(f'{op.name}_param_0 ')
param_cfg = op.config.input_quantization_config[1]
if not param_cfg.can_export(): continue

assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\
and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \
param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL)
Expand All @@ -32,6 +34,7 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph):
for s in scale:
fd.write('%f '% s)
fd.write('\n')

for op in topo_order:
if op.is_computing_op and isinstance(op, QuantableOperation):
fd.write(f'{op.name} ')
Expand Down
5 changes: 2 additions & 3 deletions ppq/parser/nxp_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def export(self, file_path: str, graph: BaseGraph,
if variable.is_parameter and not export_param: continue
for config in configs:
if config is None: continue # source_op can be None
if config.state in {QuantizationStates.ACTIVATED, QuantizationStates.BAKED,
QuantizationStates.OVERLAPPED, QuantizationStates.PASSIVE_BAKED}:
if config.state == QuantizationStates.OVERLAPPED: config = config.dominated_by
if config.can_export():

tensor_range = config.scale * pow(2, config.num_of_bits - 1)
min_val, max_val = -tensor_range, tensor_range - config.scale
min_tensor = numpy_helper.from_array(
Expand Down
Loading

0 comments on commit 7883312

Please sign in to comment.