Refactor CadenceQuantizer (#7540)
Summary:

The current class structure is hard to extend cleanly. This diff:
- Makes `CadenceQuantizer` a base class
- Creates a `CadenceDefaultQuantizer` that is exactly the same as the previous `CadenceQuantizer` class
- Removes the qconfig from the `CadenceQuantizer`, since it really belongs to the `CadenceAtenQuantizer` (it is defined per op)
- Makes both the default qconfig and the default quantizer list module-level variables

Using this structure will make it much cleaner to add new quantizers in the future.
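
As a rough sketch of that extension pattern (not part of this diff; `MyWavQuantizer` is a hypothetical name), a backend-specific quantizer can now subclass `CadenceQuantizer` and hand it whatever quantizer list it needs, typically starting from the module-level default list:

```python
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceQuantizer,
    get_cadence_default_quantizer_list_with_config,
)


class MyWavQuantizer(CadenceQuantizer):
    """Hypothetical quantizer: the default op list built from a caller-supplied qconfig."""

    def __init__(self, qconfig) -> None:
        # Start from the module-level default list; a real subclass would trim it
        # or append extra CadenceAtenQuantizer(SomePattern(), qconfig) entries
        # for additional op patterns here.
        quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
        super().__init__(quantizers)
```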

Reviewed By: zonglinpeng

Differential Revision: D67645196
mcremon-meta authored and facebook-github-bot committed Jan 8, 2025
1 parent 1bac885 commit a8dc686
Showing 4 changed files with 50 additions and 32 deletions.
7 changes: 5 additions & 2 deletions backends/cadence/aot/compiler.py
@@ -17,7 +17,10 @@
     print_memory_planning_info,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceDefaultQuantizer,
+    CadenceQuantizer,
+)
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
@@ -136,7 +139,7 @@ def quantize_pt2(
 
     # Instantiate the quantizer to CadenceQuantizer if not supplied
     if not quantizer:
-        quantizer = CadenceQuantizer()
+        quantizer = CadenceDefaultQuantizer()
 
     # Get converted graph module
     converted_gm = convert_pt2(model, inputs, quantizer)
4 changes: 2 additions & 2 deletions backends/cadence/aot/export_example.py
@@ -20,7 +20,7 @@
     fuse_pt2,
 )
 
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.runtime import runtime
 from executorch.backends.cadence.runtime.executor import BundledProgramManager
 from executorch.exir import ExecutorchProgramManager
@@ -74,7 +74,7 @@ def export_model(
     )
 
     # Instantiate the quantizer
-    quantizer = CadenceQuantizer(qconfig)
+    quantizer = CadenceDefaultQuantizer(qconfig)
 
     # Convert the model
     converted_model = convert_pt2(model, example_inputs, quantizer)
67 changes: 41 additions & 26 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -60,6 +60,13 @@
 
 bias_qspec: Optional[QuantizationSpec] = None
 
+_default_qconfig = QuantizationConfig(
+    act_qspec,
+    act_qspec,
+    wgt_qspec,
+    None,
+)
+
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(
@@ -140,31 +147,39 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []
 
 
+def get_cadence_default_quantizer_list_with_config(
+    quantization_config: QuantizationConfig,
+) -> List[Quantizer]:
+    return [
+        CadenceAtenQuantizer(AddmmPattern(), quantization_config),
+        CadenceAtenQuantizer(BmmPattern(), quantization_config),
+        CadenceAtenQuantizer(Conv1dPattern(), quantization_config),
+        CadenceAtenQuantizer(Conv2dPattern(), quantization_config),
+        CadenceAtenQuantizer(LayerNormPattern(), quantization_config),
+        CadenceAtenQuantizer(LinearPattern(), quantization_config),
+        CadenceAtenQuantizer(MatmulPattern(), quantization_config),
+        CadenceAtenQuantizer(ReluPattern0(), quantization_config),
+        CadenceAtenQuantizer(ReluPattern1(), quantization_config),
+    ]
+
+
 class CadenceQuantizer(ComposableQuantizer):
-    def __init__(
-        self, quantization_config: Optional[QuantizationConfig] = None
-    ) -> None:
-        static_qconfig = (
-            QuantizationConfig(
-                act_qspec,
-                act_qspec,
-                wgt_qspec,
-                None,
-            )
-            if not quantization_config
-            else quantization_config
-        )
+    """
+    Generic CadenceQuantizer. Although it can be used directly, it is typically a base
+    class for explicitly defined quantizers (like CadenceDefaultQuantizer).
+    """
 
-        super().__init__(
-            [
-                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
-                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
-                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
-                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
-                CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
-                CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
-            ]
-        )
+    def __init__(self, quantizers: List[Quantizer]) -> None:
+        super().__init__(quantizers)
+
+
+class CadenceDefaultQuantizer(CadenceQuantizer):
+    """
+    Default quantizer for Cadence backend.
+    """
+
+    def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None:
+        if qconfig is None:
+            qconfig = _default_qconfig
+        quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
+        super().__init__(quantizers)
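
A minimal usage sketch of the resulting API (assumed call pattern, not part of the diff): the default construction is behaviorally unchanged, and the base class now simply wraps an explicit list.

```python
from executorch.backends.cadence.aot.quantizer.quantizer import (
    _default_qconfig,
    CadenceDefaultQuantizer,
    CadenceQuantizer,
    get_cadence_default_quantizer_list_with_config,
)

# Drop-in replacement for the old CadenceQuantizer() default:
quantizer = CadenceDefaultQuantizer()

# Equivalent construction spelled out through the base class
# (_default_qconfig is the module-internal default defined above):
explicit = CadenceQuantizer(
    get_cadence_default_quantizer_list_with_config(_default_qconfig)
)
```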
4 changes: 2 additions & 2 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -12,7 +12,7 @@
 from executorch.backends.cadence.aot.compiler import export_to_edge
 
 from executorch.backends.cadence.aot.pass_utils import count_node
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.aot.remove_ops import (
     RemoveAliasCopyOpPass,
     RemoveCloneOpPass,
@@ -465,7 +465,7 @@ def forward(self, x):
 
         # Run the standard quant/convert steps, but without fusing
         # this leaves two redundant quant/dequant pairs to test with
-        quantizer = CadenceQuantizer()
+        quantizer = CadenceDefaultQuantizer()
         model_exp = export_for_training(M(), (inp,)).module()
         prepared_model = prepare_pt2e(model_exp, quantizer)
         prepared_model(inp)
