Qnn batchnorm support input with rank 2 (#21469)
### Description
QNN BatchNorm: support inputs with rank 2.
Update the quantization script to quantize the BatchNorm bias as int32.
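
As a rough illustration of the bias handling described above (a sketch only; the function name and scale values are placeholders, not ONNX Runtime APIs): symmetric int32 quantization pins the zero point to 0, derives the bias scale from the input and scale-tensor scales, and rounds and clamps the result to the int32 range.

```python
import numpy as np

def quantize_bias_int32_symmetric(bias, input_scale, weight_scale):
    """Illustrative sketch: symmetric int32 quantization of a BatchNorm bias."""
    bias_scale = input_scale * weight_scale  # bias scale derived from input and scale tensors
    q = np.round(np.asarray(bias, dtype=np.float64) / bias_scale)
    # Symmetric quantization: zero point is 0, values clamp to the int32 range.
    q = np.clip(q, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
    return q.astype(np.int32), bias_scale, 0  # (quantized values, scale, zero point)

# Example with the bias values used in the new rank-2 unit test (scales are made up).
print(quantize_bias_int32_symmetric([1.1, 2.1], input_scale=0.0627, weight_scale=0.0078))
```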

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
HectorSVC and justinchuby authored Jul 25, 2024
1 parent 4167b68 commit c235178
Showing 9 changed files with 111 additions and 50 deletions.
@@ -632,7 +632,7 @@ bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 3)) {
return false;
}

@@ -392,15 +392,23 @@ class BatchNormOpBuilder : public BaseOpBuilder {
const double rmin,
QnnQuantParamsWrapper& quant_param,
std::vector<uint8_t>& raw_tensor) const {
bool symmetric = false;
if (info.quant_param.IsQuantized()) {
raw_tensor.resize(double_tensor.size());
size_t data_size = double_tensor.size();
// QNN BatchNorm int32 bias requires symmetric quantization
if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) {
data_size *= sizeof(int32_t);
symmetric = true;
}
raw_tensor.resize(data_size);
float scale = 0.0f;
int zero_point = 0;
int32_t zero_point = 0;
ORT_RETURN_IF_ERROR(utils::GetQuantParams(static_cast<float>(rmin),
static_cast<float>(rmax),
info.qnn_data_type,
scale,
zero_point));
zero_point,
symmetric));
quant_param = QnnQuantParamsWrapper(scale, zero_point);
for (size_t i = 0; i < double_tensor.size(); ++i) {
// onnx only supports 8 bits quantization
@@ -411,6 +419,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
} else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) {
int8_t quant_value = static_cast<int8_t>(quant_value_int);
raw_tensor[i] = *reinterpret_cast<uint8_t*>(&quant_value);
} else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) {
int32_t quant_value = static_cast<int32_t>(quant_value_int);
size_t pos = i * sizeof(int32_t);
std::memcpy(&raw_tensor[pos], reinterpret_cast<uint8_t*>(&quant_value), sizeof(int32_t));
} else {
// TODO(adrianlizarraga): Should support 16-bit quantization as well.
ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type);
@@ -444,8 +456,7 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0.");
const size_t input_rank = input_shape.size();

ORT_RETURN_IF(input_rank <= 2 || input_rank > 4,
"QNN BatchNorm only supports input ranks of size 3 or 4.");
ORT_RETURN_IF(input_rank > 4, "QNN BatchNorm only supports input ranks of size <= 4.");

const uint32_t num_channels = input_shape[1];

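For reference, a small sketch (Python, not the EP's actual code) of how the int32 quantized values above are laid out in the raw byte buffer: each value occupies sizeof(int32_t) bytes, matching the data_size resize and the per-element memcpy in the builder.

```python
import numpy as np

def pack_int32_quant_values(quant_values):
    # Mirrors raw_tensor: 4 bytes per quantized value, native byte order,
    # written at offset i * sizeof(int32_t) for element i.
    return np.asarray(quant_values, dtype=np.int32).tobytes()

raw = pack_int32_quant_values([17, -33])
assert len(raw) == 2 * 4  # data_size = double_tensor.size() * sizeof(int32_t)
```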
@@ -79,7 +79,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
if (is_quantized_tensor) {
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(true, type_proto, qnn_data_type));
float scale = 0.0f;
int zero_point = 0;
int32_t zero_point = 0;
float rmax = 1.0f;
float rmin = 1.0f;
ORT_RETURN_IF_ERROR(utils::GetQuantParams(rmin,
23 changes: 19 additions & 4 deletions onnxruntime/core/providers/qnn/builder/qnn_utils.cc
@@ -509,6 +509,9 @@ Status GetQminQmax(const Qnn_DataType_t qnn_data_type,
} else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
qmin = static_cast<T>(std::numeric_limits<uint16_t>::min());
qmax = static_cast<T>(std::numeric_limits<uint16_t>::max());
} else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) {
qmin = static_cast<T>(std::numeric_limits<int32_t>::min());
qmax = static_cast<T>(std::numeric_limits<int32_t>::max());
} else {
ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type);
}
@@ -519,15 +522,27 @@ Status GetQuantParams(float rmin,
float rmax,
const Qnn_DataType_t qnn_data_type,
float& scale,
int& zero_point) {
int32_t& zero_point,
bool symmetric) {
std::tie(rmin, rmax) = CheckMinMax(rmin, rmax);
if (symmetric) {
float abs_max = std::max(abs(rmax), abs(rmin));
rmax = abs_max;
rmin = -abs_max;
}

float qmin = 0.0f;
float qmax = 255.0f;
ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax));

scale = (rmax - rmin) / (qmax - qmin);
const float initial_zero_point = qmin - (rmin / scale);
zero_point = static_cast<int>(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point)));
float initial_zero_point = 0.0f;
if (symmetric) {
initial_zero_point = std::round(rmin + rmax) / 2;
} else {
initial_zero_point = qmin - (rmin / scale);
}
zero_point = static_cast<int32_t>(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point)));
// To match QNN quantization definition
zero_point = 0 - zero_point;
return Status::OK();
@@ -541,7 +556,7 @@ double Dequantize(int32_t offset, float scale, const double quant_value) {

Status Quantize(const double double_value,
const float scale,
const int zero_point,
const int32_t zero_point,
const Qnn_DataType_t qnn_data_type,
int& quant_value) {
int qmin = 0;
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/qnn/builder/qnn_utils.h
@@ -93,13 +93,14 @@ Status GetQuantParams(float rmin,
float rmax,
const Qnn_DataType_t qnn_data_type,
float& scale,
int& zero_point);
int32_t& zero_point,
bool symmetric = false);

double Dequantize(int32_t offset, float scale, const double quant_value);

Status Quantize(const double double_value,
const float scale,
const int zero_point,
const int32_t zero_point,
const Qnn_DataType_t qnn_data_type,
int& quant_value);

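A rough Python rendering of the GetQuantParams behavior declared above, including the new symmetric path (a sketch under the assumption that numpy's rint matches RoundHalfToEven; it is not ORT's implementation):

```python
import numpy as np

def get_quant_params(rmin, rmax, qmin, qmax, symmetric=False):
    if symmetric:
        # Symmetric: widen the range to [-abs_max, abs_max] so the zero point lands on 0.
        abs_max = max(abs(rmin), abs(rmax))
        rmin, rmax = -abs_max, abs_max
    scale = (rmax - rmin) / (qmax - qmin)
    if symmetric:
        initial_zero_point = round(rmin + rmax) / 2  # 0 for a symmetric range
    else:
        initial_zero_point = qmin - (rmin / scale)
    zero_point = int(np.rint(np.clip(initial_zero_point, qmin, qmax)))  # saturate, round half to even
    return scale, -zero_point  # negated to match QNN's offset convention

# int32 symmetric example: the zero point collapses to 0.
print(get_quant_params(-3.0, 5.0, float(-2**31), float(2**31 - 1), symmetric=True))
```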
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/quantization/operators/norm.py
@@ -12,7 +12,7 @@ def __init__(self, onnx_quantizer, onnx_node):

def quantize(self):
node = self.node
assert node.op_type == "InstanceNormalization" or node.op_type == "LayerNormalization"
assert node.op_type in {"InstanceNormalization", "LayerNormalization", "BatchNormalization"}

# Input
self.quantizer.quantize_activation_tensor(node.input[0])
1 change: 1 addition & 0 deletions onnxruntime/python/tools/quantization/registry.py
@@ -82,6 +82,7 @@
"Where": QDQWhere,
"InstanceNormalization": QDQNormalization,
"LayerNormalization": QDQNormalization,
"BatchNormalization": QDQNormalization,
}


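With BatchNormalization registered to QDQNormalization, the standard static quantization flow now wraps it in Q/DQ nodes (and, per the changes above, produces an int32 bias). A hedged usage sketch; the model paths, input name, and shape are placeholders, and the calibration reader should be replaced with real data:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomCalibrationReader(CalibrationDataReader):
    """Feeds a few random batches; stand-in for a real calibration data reader."""
    def __init__(self, input_name, shape, count=8):
        self._batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(count)]
        )

    def get_next(self):
        return next(self._batches, None)

quantize_static(
    "model_with_batchnorm.onnx",      # placeholder input model
    "model_with_batchnorm.qdq.onnx",  # placeholder output model
    RandomCalibrationReader("input", (1, 2, 8, 8)),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8,
)
```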
99 changes: 66 additions & 33 deletions onnxruntime/test/providers/qnn/batch_norm_htp_test.cc
@@ -80,8 +80,7 @@ template <typename FLOAT_TYPE>
static GetTestModelFn BuildBatchNormTestCase(const TestInputDef<FLOAT_TYPE>& input_def,
const TestInputDef<FLOAT_TYPE>& scale_def,
const TestInputDef<FLOAT_TYPE>& bias_def) {
ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs.
ORT_ENFORCE(input_def.GetShape().size() > 2); // Need at least rank 3 data for convenience.
ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs.

return [input_def, scale_def, bias_def](ModelTestBuilder& builder) {
const auto& input_shape = input_def.GetShape();
@@ -103,45 +102,39 @@ static GetTestModelFn BuildBatchNormTestCase(const TestInputDef<FLOAT_TYPE>& inp
};
}

template <typename InputQType, typename ScaleQType, typename BiasQType>
template <typename InputQType, typename ScaleQType>
GetTestQDQModelFn<InputQType> BuildQDQBatchNormTestCase(const TestInputDef<float>& input_def,
const TestInputDef<float>& scale_def,
const TestInputDef<float>& bias_def) {
ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs.
ORT_ENFORCE(input_def.GetShape().size() > 2); // Need at least rank 3 data for convenience.
ORT_ENFORCE(input_def.IsRawData()); // Need raw data to compute mean and variance inputs.

return [input_def, scale_def, bias_def](ModelTestBuilder& builder,
std::vector<QuantParams<InputQType>>& output_qparams) {
const auto& input_shape = input_def.GetShape();
const auto& input_data = input_def.GetRawData();
const int64_t num_channels = input_shape[1];

bool symmetric = sizeof(InputQType) == sizeof(uint16_t);
NodeArg* input = MakeTestInput(builder, input_def);
QuantParams<InputQType> input_qparams = GetTestInputQuantParams<InputQType>(input_def);
QuantParams<InputQType> input_qparams = GetTestInputQuantParams<InputQType>(input_def, symmetric);
NodeArg* input_qdq = AddQDQNodePair<InputQType>(builder, input, input_qparams.scale, input_qparams.zero_point);

NodeArg* scale = MakeTestInput(builder, scale_def);
QuantParams<ScaleQType> scale_qparams = GetTestInputQuantParams<ScaleQType>(scale_def);
NodeArg* scale_qdq = AddQDQNodePair<ScaleQType>(builder, scale, scale_qparams.scale, scale_qparams.zero_point);

NodeArg* bias = MakeTestInput(builder, bias_def);
QuantParams<BiasQType> bias_qparams = GetTestInputQuantParams<BiasQType>(bias_def);
NodeArg* bias_qdq = AddQDQNodePair<BiasQType>(builder, bias, bias_qparams.scale, bias_qparams.zero_point);
NodeArg* bias_qdq;
// bias (as int32) => DQ =>
bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale, true);

std::vector<float> mean_vals(num_channels);
std::vector<float> var_vals(num_channels);
ComputeChannelMeanAndVar(input_data, input_shape, mean_vals, var_vals);

NodeArg* mean = builder.MakeInitializer<float>({num_channels}, mean_vals);
QuantParams<InputQType> mean_qparams = GetDataQuantParams(mean_vals);
NodeArg* mean_qdq = AddQDQNodePair<InputQType>(builder, mean, mean_qparams.scale, mean_qparams.zero_point);

NodeArg* var = builder.MakeInitializer<float>({num_channels}, var_vals);
QuantParams<InputQType> var_qparams = GetDataQuantParams(var_vals);
NodeArg* var_qdq = AddQDQNodePair<InputQType>(builder, var, var_qparams.scale, var_qparams.zero_point);

auto* batchnorm_output = builder.MakeIntermediate();
builder.AddNode("BatchNormalization", {input_qdq, scale_qdq, bias_qdq, mean_qdq, var_qdq},
builder.AddNode("BatchNormalization", {input_qdq, scale_qdq, bias_qdq, mean, var},
{batchnorm_output});

AddQDQNodePairWithOutputAsGraphOutput<InputQType>(builder, batchnorm_output, output_qparams[0].scale, output_qparams[0].zero_point);
@@ -155,6 +148,7 @@ GetTestQDQModelFn<InputQType> BuildQDQBatchNormTestCase(const TestInputDef<float
* \param input_shape The input's shape.
* \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None).
*/
template <typename InputQType, typename ScaleQType>
static void RunBatchNormQDQTest(const TestInputDef<float>& input_def,
const TestInputDef<float>& scale_def,
const TestInputDef<float>& bias_def,
@@ -169,9 +163,9 @@ static void RunBatchNormQDQTest(const TestInputDef<float>& input_def,

// Runs model with DQ -> BatchNormalization -> Q and compares the outputs of the CPU and QNN EPs.
TestQDQModelAccuracy(BuildBatchNormTestCase(input_def, scale_def, bias_def),
BuildQDQBatchNormTestCase<uint8_t, uint8_t, uint8_t>(input_def, scale_def, bias_def),
BuildQDQBatchNormTestCase<InputQType, ScaleQType>(input_def, scale_def, bias_def),
provider_options,
11,
21,
expected_ep_assignment,
tolerance);
}
@@ -199,31 +193,69 @@ static void RunBatchNormFP16Test(const TestInputDef<float>& input_def,
expected_ep_assignment);
}

// BatchNorm QDQ model, input with rank 2.
TEST_F(QnnHTPBackendTests, BatchNormRank2) {
constexpr int64_t num_channels = 2;

RunBatchNormQDQTest<uint8_t, uint8_t>(TestInputDef<float>({4, num_channels}, false,
{-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f}), // Input data
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::All);
}

// TODO: FIX TRANSLATION!!!
// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
// Use an input of rank 3.
// Accuracy issue with Linux simulator, not sure with Android device
// Inaccuracy detected for output 'output_0', element 1
// output_range=4.8666362762451172, tolerance=0.40000000596046448%.
// Expected val (f32@CPU_EP): 1.0999999046325684
// qdq@QNN_EP val: -0.17176364362239838 (err: 1.2717635631561279, err/output_range: 26.132291793823242%)
// qdq@CPU_EP val: 1.1069211959838867 (err: 0.0069212913513183594, err/output_range: 0.14221921563148499%)
// abs(qdq@QNN_EP - qdq@CPU_EP) / output_range = 25.990072250366211%
//
// Inaccuracy detected for output 'output_0', element 2
// output_range=4.8666362762451172, tolerance=0.40000000596046448%.
// Expected val (f32@CPU_EP): 2.3247356414794922
// qdq@QNN_EP val: -0.17176364362239838 (err: 2.4964993000030518, err/output_range: 51.298248291015625%)
// qdq@CPU_EP val: 2.3474364280700684 (err: 0.022700786590576172, err/output_range: 0.46645742654800415%)
#if defined(_WIN32)
TEST_F(QnnHTPBackendTests, BatchNorm1D) {
constexpr int64_t num_channels = 2;

RunBatchNormQDQTest(TestInputDef<float>({1, num_channels, 3}, false, {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::All);
RunBatchNormQDQTest<uint8_t, uint8_t>(TestInputDef<float>({1, num_channels, 3}, false,
{-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::All);
}
#endif

// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
// Use an input of rank 4.
TEST_F(QnnHTPBackendTests, BatchNorm2D_a8w8) {
constexpr int64_t num_channels = 2;
std::vector<float> input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f,
-7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f};

RunBatchNormQDQTest<uint8_t, uint8_t>(TestInputDef<float>({2, num_channels, 2, 2}, false, input_data), // Input data
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::All);
}

// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit.
// Use an input of rank 4.
TEST_F(QnnHTPBackendTests, BatchNorm2D) {
TEST_F(QnnHTPBackendTests, BatchNorm2D_a16w8) {
constexpr int64_t num_channels = 2;
std::vector<float> input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f,
-7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f};

RunBatchNormQDQTest(TestInputDef<float>({2, num_channels, 2, 2}, false, input_data), // Input data
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::All,
// Require a slightly increased tolerance on Windows ARM64 (from 0.4% to 0.6%).
QDQTolerance(0.006f));
RunBatchNormQDQTest<uint16_t, uint8_t>(TestInputDef<float>({2, num_channels, 2, 2}, false, input_data), // Input data
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::All);
}

// Test FP16 BatchNormalization on the HTP backend.
@@ -272,10 +304,11 @@ TEST_F(QnnHTPBackendTests, BatchNorm_FP32_as_FP16) {
TEST_F(QnnHTPBackendTests, BatchNorm3D) {
constexpr int64_t num_channels = 2;
constexpr int64_t num_elems = 1 * num_channels * 3 * 4 * 5;
RunBatchNormQDQTest(TestInputDef<float>({1, num_channels, 3, 4, 5}, false, std::vector<float>(num_elems)), // Input data (all zeros)
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::None);
RunBatchNormQDQTest<uint8_t, uint8_t>(TestInputDef<float>({1, num_channels, 3, 4, 5}, false,
std::vector<float>(num_elems)), // Input data (all zeros)
TestInputDef<float>({num_channels}, true, {1.0f, 2.0f}), // Scale initializer
TestInputDef<float>({num_channels}, true, {1.1f, 2.1f}), // Bias initializer
ExpectedEPNodeAssignment::None);
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
6 changes: 3 additions & 3 deletions onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -42,7 +42,7 @@ struct QuantParams {
symmetric);
}

static QuantParams<QType> Compute(float rmin, float rmax, QType qmin, QType qmax, bool symmetric = false) {
static QuantParams<QType> Compute(float rmin, float rmax, float qmin, float qmax, bool symmetric = false) {
// Ensure a minimum range of 0.0001 (required by QNN)
rmax = std::max(rmax, rmin + 0.0001f);

@@ -56,8 +56,8 @@ struct QuantParams {
rmin = -abs_max;
}

float qmin_flt = static_cast<float>(qmin);
float qmax_flt = static_cast<float>(qmax);
float qmin_flt = qmin;
float qmax_flt = qmax;
const float scale = (rmax - rmin) / (qmax_flt - qmin_flt);
float initial_zero_point = 0.0f;
