Skip to content

Commit

Permalink
[TransposeOptimizer] Support Unsqueeze/Transpose of input consumed by…
Browse files Browse the repository at this point in the history
… per-axis DQ (microsoft#21821)

### Description
Follow-up to: microsoft#21793

- Support looking past a per-axis DQ to do in-place Unsqueeze/Transpose
of initializers
- Support looking past a per-axis DQ to cancel a Transpose or Squeeze.

### Test models
For all test models, the transpose optimizer pushes a Transpose through
a Mul's input[0]. The Mul's input[1] is optionally unsqueezed and then
transposed.

### I. Test in-place unsqueeze and transpose of per-axis quantized
weight
Original model has input[1] with shape (3,)
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/37b6f60c-77d2-4bd3-8ca2-58dc7c88a304"
/>
</details>

Optimized model has input[1] with shape (1, 3, 1, 1). The initializer
was unsqueezed and transposed in-place.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/adb72757-a164-400c-bfef-2a05f0e35825"
/>
</details>

### II. Test canceling existing Squeeze before per-axis DQ
Original model has input[1] that is squeezed.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/f27e6742-b563-42a9-ad06-bb3178b0ceb8"
/>
</details>

Optimized model unsqueezed and transposed input[1]. The original squeeze
was removed due to the unsqueeze, leaving only the Transpose.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/e56261d4-eba6-4a9f-847b-dcd33548dd07"
/>
</details>

### III. Test canceling existing Transpose before per-axis DQ
Original model has input[1] that is transposed.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/f157e04a-572a-479d-8e3b-cf57954df5c0"
/>
</details>

Optimized model transposed input[1], thus canceling the existing
transpose.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/63d742ce-3762-4ab2-bdb0-1b507886da9d"
/>
</details>

### IV. Test QDQ fix-up of Transpose/Unsqueeze for per-axis quantization
Original model has input[1] that can be broadcasted.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/96c0092c-22ec-486d-882e-e2cb59ffe324"
/>
</details>

The main transpose optimization loop inserts float32 Unsqueeze and
Transpose after the DQ. The qdq fix-up pass inserts new per-axis Q/DQ
ops after the inserted nodes.
<details><summary>click expand model image</summary>
<img
src="https://github.com/user-attachments/assets/b6f89c11-974d-4b35-922f-11effdf06883"
/>
</details>


### Motivation and Context
Enables the TransposeOptimizer to support more models with per-axis QDQ
nodes. Per-axis quantization can improve model accuracy and is used by
EPs like QNN.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
  • Loading branch information
adrianlizarraga and edgchen1 authored Sep 6, 2024
1 parent 23f6604 commit b011f6f
Show file tree
Hide file tree
Showing 9 changed files with 763 additions and 253 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <gsl/gsl>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -59,7 +60,7 @@ struct OptimizerCtx {
/// <returns>{0}</returns>
inline std::vector<size_t> FirstInput(OptimizerCtx&, api::NodeRef&) { return {0}; }

std::vector<int64_t> InvertPerm(const std::vector<int64_t>& perm);
std::vector<int64_t> InvertPerm(gsl::span<const int64_t> perm);

// Transpose all inputs and all outputs
bool HandleSimpleNode(HandlerArgs& args);
Expand Down
316 changes: 292 additions & 24 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc

Large diffs are not rendered by default.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import onnx


def subgraph_1d_const_input_dq(inputs, initializers, nodes) -> str:
"""
Creates mul_weight -> DQ. mul_weight is a constant of rank 1.
"""
mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8)
mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight")
initializers.append(mul_weight)

dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["mul_weight", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=0,
)
)

return dq_output_name


def subgraph_1d_input_dq(inputs, initializers, nodes) -> str:
"""
Creates input1 -> DQ. input1 is a graph input of rank 1.
"""
input1_shape = (3,)
inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape))

dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["input1", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=0,
)
)

return dq_output_name


def subgraph_4d_input_squeeze_dq(inputs, initializers, nodes) -> str:
"""
Creates input1 -> Squeeze -> DQ. input1 is a graph input of rank 4.
"""
input1_shape = (1, 1, 1, 3)
inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape))

axes_data = np.array([0, 1, 2], dtype=np.int64)
initializers.append(onnx.numpy_helper.from_array(axes_data, "axes_const"))
nodes.append(onnx.helper.make_node("Squeeze", ["input1", "axes_const"], ["squeeze_out"], name="squeeze_node"))

dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["squeeze_out", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=0,
)
)

return dq_output_name


def subgraph_4d_input_transpose_dq(inputs, initializers, nodes) -> str:
"""
Creates input1 -> Transpose -> DQ. input1 is a graph input of rank 4.
"""
input1_shape = (1, 3, 1, 1)
inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape))

perm = [0, 2, 3, 1] # To channel-last
nodes.append(onnx.helper.make_node("Transpose", ["input1"], ["tp_out_"], perm=perm, name="transpose_"))

dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["tp_out_", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=-1,
)
)

return dq_output_name


def make_model(model_path: str, build_mul_input_1_subgraph):
"""
Creates a QDQ model with a per-axis DQ input that is Unsqueezed and Transposed by the Transpose optimizer.
"""
input0_shape = (1, 3, 4, 4)

inputs = [onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)]
outputs = [onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None)]

mul_weight_scales_data = np.array([1.0, 1.0, 1.0], dtype=np.float32)
mul_weight_zps_data = np.array([0, 0, 0], dtype=np.int8)

initializers = [
onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1"),
onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128"),
onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255"),
onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0"),
onnx.numpy_helper.from_array(mul_weight_scales_data, "mul_weight_scales"),
onnx.numpy_helper.from_array(mul_weight_zps_data, "mul_weight_zps"),
]
nodes = []

# Transpose to channel-last
tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1))
nodes.append(tp0_node)

# Q_0
q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node")
nodes.append(q0_node)

# DQ_0
dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node")
nodes.append(dq0_node)

# Sigmoid
sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node")
nodes.append(sigmoid_node)

# Q_1
q1_node = onnx.helper.make_node(
"QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node"
)
nodes.append(q1_node)

# DQ_1
dq1_node = onnx.helper.make_node(
"DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node"
)
nodes.append(dq1_node)

# DQ for mul input[1]
mul_input_1_name = build_mul_input_1_subgraph(inputs, initializers, nodes)

# Mul
mul_node = onnx.helper.make_node("Mul", ["dq1_out", mul_input_1_name], ["mul_out"], name="mul_node")
nodes.append(mul_node)

# Q_2
q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node")
nodes.append(q2_node)

# DQ_2
dq2_node = onnx.helper.make_node(
"DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node"
)
nodes.append(dq2_node)

# Transpose to channel-first
tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["output0"], name="tp1_node", perm=(0, 3, 1, 2))
nodes.append(tp1_node)

graph = onnx.helper.make_graph(
nodes,
"transpose_opt_unsqueeze_dq_axis",
inputs,
outputs,
initializer=initializers,
)
opset_imports = [
onnx.helper.make_opsetid("", 19),
]
qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports)

print("[INFO]: Running onnx.checker on qdq model")
qdq_model = onnx.shape_inference.infer_shapes(qdq_model)
onnx.checker.check_model(qdq_model, True)

print(f"[INFO]: Saving {model_path}")
onnx.save_model(qdq_model, model_path)


if __name__ == "__main__":
make_model(
"transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx",
subgraph_1d_input_dq,
)
make_model(
"transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx",
subgraph_1d_const_input_dq,
)
make_model(
"transpose_optimizer_cancel_squeeze_per_axis_dq.onnx",
subgraph_4d_input_squeeze_dq,
)
make_model(
"transpose_optimizer_cancel_transpose_per_axis_dq.onnx",
subgraph_4d_input_transpose_dq,
)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit b011f6f

Please sign in to comment.