Add support for FakeConvert in CompressQuantizeWeights
Ticket: CVS-129925
mateusztabaka authored and beleiuandrei committed Jan 25, 2024
1 parent 1e846ad commit ab300e9
Showing 5 changed files with 351 additions and 54 deletions.
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

@@ -10,13 +10,16 @@ namespace ov {
namespace pass {

class CompressQuantizeWeights;
class CompressWeightsWithFakeQuantize;
class CompressWeightsWithFakeConvert;

} // namespace pass
} // namespace ov

/*
CompressQuantizeWeights transformation goal is to pre-quantize data to minimize runtime calculations with constant
data. To achieve this goal we perform FakeQuantize decomposition to separate quantization from dequantization in it.
CompressWeightsWithFakeQuantize transformation goal is to pre-quantize data to minimize runtime calculations with
constant data. To achieve this goal we perform FakeQuantize decomposition to separate quantization from
dequantization in it.
Initial graph (FakeQuantize where all inputs are Constants):
@@ -58,7 +61,46 @@ class CompressQuantizeWeights;
With that we can skip the same calculations at runtime and make loading of such sub-graphs into the plugin faster.
Additionally, the zero point can be fused into the weights if it doesn't affect accuracy.
*/
class ov::pass::CompressQuantizeWeights : public ov::pass::MatcherPass {
class ov::pass::CompressWeightsWithFakeQuantize : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("CompressWeightsWithFakeQuantize", "0");

CompressWeightsWithFakeQuantize();
};
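For intuition, here is a minimal standalone sketch (illustrative, not part of this diff) of the arithmetic the decomposition relies on, following the FakeQuantize reference formula; the quantization half is what gets folded into a low-precision constant, while the dequantization half becomes the Subtract/Multiply pair left in the graph:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    // FakeQuantize reference formula:
    //   q = round((clamp(x, in_low, in_high) - in_low) / (in_high - in_low) * (levels - 1))
    //   y = q / (levels - 1) * (out_high - out_low) + out_low
    const float x = 0.37f;
    const float in_low = 0.0f, in_high = 1.0f;
    const float out_low = -1.27f, out_high = 1.28f;
    const int levels = 256;

    // Quantization half: a small integer representable as (u)int8.
    const float clamped = std::min(std::max(x, in_low), in_high);
    const float q = std::nearbyint((clamped - in_low) / (in_high - in_low) * (levels - 1));

    // Dequantization half: the same output-range mapping rewritten in
    // scale/zero-point form, i.e. the Subtract -> Multiply subgraph.
    const float scale = (out_high - out_low) / (levels - 1);
    const float zero_point = -out_low / scale;
    const float y = (q - zero_point) * scale;

    std::printf("q = %.0f, y = %f\n", q, y);
    return 0;
}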

/*
CompressWeightsWithFakeConvert replaces a FakeConvert node whose inputs are constant with the following subgraph:
+----------+
| Constant |
| (float8) |
+----+-----+
|
v
+----------+
| Convert |
| (float32)|
+----+-----+
|
v
+----------+ +--------+
| Subtract |<----| -shift |
+----+-----+ +--------+
|
v
+----------+ +---------+
| Multiply |<----| 1/scale |
+----------+ +---------+
*/
class ov::pass::CompressWeightsWithFakeConvert : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("CompressWeightsWithFakeConvert", "0");

CompressWeightsWithFakeConvert();
};
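A note on the signs in the diagram, based on my reading of the FakeConvert-13 semantics (hedged accordingly): the op emulates low-precision storage as

    y = (downconvert(x * scale - shift) + shift) / scale

so after the constant part w_c = downconvert(w * scale - shift) is folded into the float8 weights, the work left at runtime is

    y = (Convert(w_c) - (-shift)) * (1 / scale)

which is exactly the Convert -> Subtract -> Multiply chain above. Expressing the addition of shift as subtraction of a negated constant keeps the dequantization in the usual Subtract-then-Multiply zero-point form.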

class ov::pass::CompressQuantizeWeights : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("CompressQuantizeWeights", "0");
CompressQuantizeWeights();
182 changes: 133 additions & 49 deletions src/common/offline_transformations/src/compress_quantize_weigths.cpp
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

@@ -8,6 +8,7 @@
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/fake_convert.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/subtract.hpp"
@@ -17,6 +18,7 @@
#include "openvino/reference/autobroadcast_binop.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/fake_quantize.hpp"
#include "transformations/utils/utils.hpp"
#include "validation_util.hpp"

static bool has_dequantization_subgraph(const std::shared_ptr<ov::Node>& fq,
@@ -60,7 +62,7 @@ static void replace_with_dequantize_subgraph(const std::shared_ptr<ov::op::v0::F
bool zero_point_is_zero,
const ov::Tensor& zero_point_tensor = {});

ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
ov::pass::CompressWeightsWithFakeQuantize::CompressWeightsWithFakeQuantize() {
auto weights_const_pattern = pattern::wrap_type<op::v0::Constant>();
auto weights_convert_pattern = pattern::wrap_type<op::v0::Convert>({weights_const_pattern});
OutputVector weights_options{weights_const_pattern, weights_convert_pattern};
@@ -220,8 +222,93 @@ ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(fq_pattern, "CompressQuantizeWeights");
this->register_matcher(m, callback);
auto m = std::make_shared<pattern::Matcher>(fq_pattern, "CompressWeightsWithFakeQuantize");
register_matcher(m, callback);
}

static std::shared_ptr<ov::op::v0::Constant> get_fake_convert_shift(
const std::shared_ptr<ov::op::v13::FakeConvert>& fake_convert) {
if (fake_convert->get_input_size() < 3)
return nullptr;
const auto shift = ov::as_type_ptr<ov::op::v0::Constant>(fake_convert->get_input_node_shared_ptr(2));
if (!shift)
return nullptr;
float value = -1.0f;
// A shift that folds to a single zero value is treated as absent (nullptr),
// so the dequantization subgraph can skip the Subtract entirely.
if (!ov::op::util::get_single_value(shift, value) || value != 0.0f)
return shift;
return nullptr;
}

ov::pass::CompressWeightsWithFakeConvert::CompressWeightsWithFakeConvert() {
auto weights_const_pattern = pattern::wrap_type<op::v0::Constant>();
auto weights_convert_pattern = pattern::wrap_type<op::v0::Convert>({weights_const_pattern});
OutputVector weights_options{weights_const_pattern, weights_convert_pattern};
auto weights_pattern = std::make_shared<pattern::op::Or>(weights_options);
auto fake_convert_pattern = pattern::wrap_type<op::v13::FakeConvert>(
{weights_pattern, pattern::wrap_type<op::v0::Constant>(), pattern::wrap_type<op::v0::Constant>()});
auto fake_convert_pattern2 =
pattern::wrap_type<op::v13::FakeConvert>({weights_pattern, pattern::wrap_type<op::v0::Constant>()});
auto root = std::make_shared<pattern::op::Or>(OutputVector{fake_convert_pattern, fake_convert_pattern2});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_map();
const auto fake_convert = ov::as_type_ptr<op::v13::FakeConvert>(m.get_match_root());
auto weights = pattern_map.at(weights_const_pattern);

NodeVector from{weights, fake_convert, fake_convert->get_input_node_shared_ptr(1)};
NodeRegistry node_registry;

if (weights->get_output_element_type(0) != fake_convert->get_input_element_type(0)) {
weights = std::make_shared<op::v0::Convert>(weights, fake_convert->get_input_element_type(0));
}
const auto scale = fake_convert->input_value(1);
weights = std::make_shared<op::v1::Multiply>(weights, scale);
const auto shift = get_fake_convert_shift(fake_convert);
if (shift) {
from.push_back(shift);
weights = std::make_shared<op::v1::Subtract>(weights, shift);
}
const auto destination_type = element::Type(fake_convert->get_destination_type());
const auto weights_convert = std::make_shared<op::v0::Convert>(weights, destination_type);
auto compressed_weights = ov::util::constantfold_subgraph(weights_convert);
if (!compressed_weights) {
return false;
}
node_registry.add(compressed_weights);

const auto convert =
node_registry.make<op::v0::Convert>(compressed_weights, fake_convert->get_input_element_type(0));
const auto inv_scale = ov::util::constantfold_subgraph(
std::make_shared<op::v1::Power>(scale,
op::v0::Constant::create(scale.get_element_type(), Shape{}, {-1.0f})));
if (!inv_scale)
return false;
node_registry.add(inv_scale);
std::shared_ptr<op::v1::Multiply> multiply;
if (shift) {
// TODO: check if shift can be fused to weights and eliminate it
const auto neg_shift = ov::util::constantfold_subgraph(std::make_shared<op::v0::Negative>(shift));
if (!neg_shift)
return false;
node_registry.add(neg_shift);
const auto subtract = node_registry.make<op::v1::Subtract>(convert, neg_shift);
multiply = node_registry.make<op::v1::Multiply>(subtract, inv_scale);
} else {
multiply = node_registry.make<op::v1::Multiply>(convert, inv_scale);
}

compressed_weights->set_friendly_name(weights->get_friendly_name());
multiply->set_friendly_name(fake_convert->get_friendly_name());

copy_runtime_info(from, node_registry.get());

replace_node(fake_convert, multiply);

return true;
};

auto m = std::make_shared<pattern::Matcher>(root, "CompressWeightsWithFakeConvert");
register_matcher(m, callback);
}
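A sketch of how this matcher could be exercised in isolation, modeled on OpenVINO functional tests; the include path for the pass header is an assumption, and the weight/scale values are arbitrary:

#include <memory>
#include <vector>

#include "compress_quantize_weights.hpp"  // assumed include path for the pass
#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/fake_convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"

int main() {
    using namespace ov;

    auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4});
    auto weights = op::v0::Constant::create(element::f32, Shape{4, 4}, std::vector<float>(16, 0.5f));
    auto scale = op::v0::Constant::create(element::f32, Shape{}, {64.0f});
    auto shift = op::v0::Constant::create(element::f32, Shape{}, {0.0f});
    auto fc = std::make_shared<op::v13::FakeConvert>(weights, scale, shift, "f8e4m3");
    auto matmul = std::make_shared<op::v0::MatMul>(data, fc);
    auto model = std::make_shared<Model>(OutputVector{matmul}, ParameterVector{data});

    pass::Manager manager;
    manager.register_pass<pass::CompressWeightsWithFakeConvert>();
    manager.run_passes(model);

    // Expected result: FakeConvert is gone; the weights become an f8e4m3
    // Constant followed by Convert -> Multiply. The all-zero shift above is
    // dropped by get_fake_convert_shift, so no Subtract is emitted.
    return 0;
}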

static ov::Tensor tensor_from_constant(const std::shared_ptr<ov::op::v0::Constant>& constant) {
@@ -686,7 +773,6 @@ std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
const std::shared_ptr<ov::Node>& convert,
const std::shared_ptr<ov::Node>& zero_point,
bool& can_fuse_zero_point) {
std::shared_ptr<ov::op::v0::Constant> new_weights;
const auto& weights_shape = weights->get_shape();
const auto& type = weights->get_element_type();
const auto& low_precision_type = convert->get_output_element_type(0);
@@ -715,7 +801,6 @@ std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
zero_point_constant->get_shape(),
fq->get_levels(),
can_fuse_zero_point);
break;
}
case ov::element::f16: {
return compress_quantized_weights_internal(low_precision_type,
@@ -733,7 +818,6 @@ std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
zero_point_constant->get_shape(),
fq->get_levels(),
can_fuse_zero_point);
break;
}
case ov::element::bf16: {
return compress_quantized_weights_internal(low_precision_type,
@@ -751,7 +835,6 @@ std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
zero_point_constant->get_shape(),
fq->get_levels(),
can_fuse_zero_point);
break;
}
default:
return nullptr;
@@ -832,57 +915,58 @@ std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
bool zero_point_is_zero,
const ov::Tensor& zero_point_tensor,
bool& can_fuse_zero_point) {
std::shared_ptr<ov::op::v0::Constant> new_weights;
const auto& weights_shape = weights->get_shape();
const auto& type = weights->get_element_type();
switch (type) {
case ov::element::f32: {
new_weights = compress_quantized_weights_internal(low_precision_type,
weights->get_data_ptr<float>(),
weights_shape,
input_low->get_data_ptr<float>(),
input_low->get_shape(),
input_high->get_data_ptr<float>(),
input_low->get_shape(),
zero_point_tensor.data<float>(),
zero_point_tensor.get_shape(),
levels,
zero_point_is_zero,
can_fuse_zero_point);
break;
return compress_quantized_weights_internal(low_precision_type,
weights->get_data_ptr<float>(),
weights_shape,
input_low->get_data_ptr<float>(),
input_low->get_shape(),
input_high->get_data_ptr<float>(),
input_low->get_shape(),
zero_point_tensor.data<float>(),
zero_point_tensor.get_shape(),
levels,
zero_point_is_zero,
can_fuse_zero_point);
}
case ov::element::f16: {
new_weights = compress_quantized_weights_internal(low_precision_type,
weights->get_data_ptr<ov::float16>(),
weights_shape,
input_low->get_data_ptr<ov::float16>(),
input_low->get_shape(),
input_high->get_data_ptr<ov::float16>(),
input_low->get_shape(),
zero_point_tensor.data<ov::float16>(),
zero_point_tensor.get_shape(),
levels,
zero_point_is_zero,
can_fuse_zero_point);
break;
return compress_quantized_weights_internal(low_precision_type,
weights->get_data_ptr<ov::float16>(),
weights_shape,
input_low->get_data_ptr<ov::float16>(),
input_low->get_shape(),
input_high->get_data_ptr<ov::float16>(),
input_low->get_shape(),
zero_point_tensor.data<ov::float16>(),
zero_point_tensor.get_shape(),
levels,
zero_point_is_zero,
can_fuse_zero_point);
}
case ov::element::bf16: {
new_weights = compress_quantized_weights_internal(low_precision_type,
weights->get_data_ptr<ov::bfloat16>(),
weights_shape,
input_low->get_data_ptr<ov::bfloat16>(),
input_low->get_shape(),
input_high->get_data_ptr<ov::bfloat16>(),
input_low->get_shape(),
zero_point_tensor.data<ov::bfloat16>(),
zero_point_tensor.get_shape(),
levels,
zero_point_is_zero,
can_fuse_zero_point);
break;
return compress_quantized_weights_internal(low_precision_type,
weights->get_data_ptr<ov::bfloat16>(),
weights_shape,
input_low->get_data_ptr<ov::bfloat16>(),
input_low->get_shape(),
input_high->get_data_ptr<ov::bfloat16>(),
input_low->get_shape(),
zero_point_tensor.data<ov::bfloat16>(),
zero_point_tensor.get_shape(),
levels,
zero_point_is_zero,
can_fuse_zero_point);
}
default:
return nullptr;
}
return new_weights;
return nullptr;
}

ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
add_matcher<CompressWeightsWithFakeQuantize>();
add_matcher<CompressWeightsWithFakeConvert>();
}
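Since the public CompressQuantizeWeights entry point is now a GraphRewrite over both matchers, existing callers pick up FakeConvert support without code changes. A minimal usage sketch, assuming the offline transformations library is linked and `model` is an already-loaded std::shared_ptr<ov::Model>:

ov::pass::Manager manager;
manager.register_pass<ov::pass::CompressQuantizeWeights>();  // runs both weight-compression matchers
manager.run_passes(model);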