diff --git a/src/common/offline_transformations/include/compress_quantize_weights.hpp b/src/common/offline_transformations/include/compress_quantize_weights.hpp index 356ff01195ae3f..597b50828494a5 100644 --- a/src/common/offline_transformations/include/compress_quantize_weights.hpp +++ b/src/common/offline_transformations/include/compress_quantize_weights.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2023 Intel Corporation +// Copyright (C) 2018-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -10,13 +10,16 @@ namespace ov { namespace pass { class CompressQuantizeWeights; +class CompressWeightsWithFakeQuantize; +class CompressWeightsWithFakeConvert; } // namespace pass } // namespace ov /* - CompressQuantizeWeights transformation goal is to pre-quantize data to minimize runtime calculations with constant - data. To achieve this goal we perform FakeQuantize decomposition to separate quantization from dequantization in it. + CompressWeightsWithFakeQuantize transformation goal is to pre-quantize data to minimize runtime calculations with + constant data. To achieve this goal we perform FakeQuantize decomposition to separate quantization from + dequantization in it. Initial graph (FakeQuantize where all inputs are Constants): @@ -58,7 +61,46 @@ class CompressQuantizeWeights; With that we can skip same calculations in the runtime and make loading of such sub-graphs to the plugin faster. Additionally zero point can be fused to weights if it doesn't affect accuracy. */ -class ov::pass::CompressQuantizeWeights : public ov::pass::MatcherPass { +class ov::pass::CompressWeightsWithFakeQuantize : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("CompressWeightsWithFakeQuantize", "0"); + + CompressWeightsWithFakeQuantize(); +}; + +/* + CompressWeightsWithFakeConvert replaces FakeConvert node with constant inputs to the following subgraph: + + +----------+ + | Constant | + | (float8) } + +----+-----+ + | + v + +----------+ + | Convert | + | (float32)| + +----+-----+ + | + v + +----------+ +--------+ + | Subtract |<----| -shift | + +----+-----+ +--------+ + | + v + +----------+ +---------+ + | Multiply |<----| 1/scale | + +----------+ +---------+ + +*/ +class ov::pass::CompressWeightsWithFakeConvert : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("CompressWeightsWithFakeConvert", "0"); + + CompressWeightsWithFakeConvert(); +}; + +class ov::pass::CompressQuantizeWeights : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("CompressQuantizeWeights", "0"); CompressQuantizeWeights(); diff --git a/src/common/offline_transformations/src/compress_quantize_weigths.cpp b/src/common/offline_transformations/src/compress_quantize_weigths.cpp index c708517445add5..12a84f784230ad 100644 --- a/src/common/offline_transformations/src/compress_quantize_weigths.cpp +++ b/src/common/offline_transformations/src/compress_quantize_weigths.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2023 Intel Corporation +// Copyright (C) 2018-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,7 @@ #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" +#include "openvino/op/fake_convert.hpp" #include "openvino/op/fake_quantize.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/subtract.hpp" @@ -17,6 +18,7 @@ #include "openvino/reference/autobroadcast_binop.hpp" #include "openvino/reference/convert.hpp" #include "openvino/reference/fake_quantize.hpp" +#include "transformations/utils/utils.hpp" #include "validation_util.hpp" static bool has_dequantization_subgraph(const std::shared_ptr& fq, @@ -60,7 +62,7 @@ static void replace_with_dequantize_subgraph(const std::shared_ptr(); auto weights_convert_pattern = pattern::wrap_type({weights_const_pattern}); OutputVector weights_options{weights_const_pattern, weights_convert_pattern}; @@ -220,8 +222,93 @@ ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() { return true; }; - auto m = std::make_shared(fq_pattern, "CompressQuantizeWeights"); - this->register_matcher(m, callback); + auto m = std::make_shared(fq_pattern, "CompressWeightsWithFakeQuantize"); + register_matcher(m, callback); +} + +static std::shared_ptr get_fake_convert_shift( + const std::shared_ptr& fake_convert) { + if (fake_convert->get_input_size() < 3) + return nullptr; + const auto shift = ov::as_type_ptr(fake_convert->get_input_node_shared_ptr(2)); + if (!shift) + return nullptr; + float value = -1.0f; + if (!ov::op::util::get_single_value(shift, value) || value != 0.0f) + return shift; + return nullptr; +} + +ov::pass::CompressWeightsWithFakeConvert::CompressWeightsWithFakeConvert() { + auto weights_const_pattern = pattern::wrap_type(); + auto weights_convert_pattern = pattern::wrap_type({weights_const_pattern}); + OutputVector weights_options{weights_const_pattern, weights_convert_pattern}; + auto weights_pattern = std::make_shared(weights_options); + auto fake_convert_pattern = pattern::wrap_type( + {weights_pattern, pattern::wrap_type(), pattern::wrap_type()}); + auto fake_convert_pattern2 = + pattern::wrap_type({weights_pattern, pattern::wrap_type()}); + auto root = std::make_shared(OutputVector{fake_convert_pattern, fake_convert_pattern2}); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_map(); + const auto fake_convert = ov::as_type_ptr(m.get_match_root()); + auto weights = pattern_map.at(weights_const_pattern); + + NodeVector from{weights, fake_convert, fake_convert->get_input_node_shared_ptr(1)}; + NodeRegistry node_registry; + + if (weights->get_output_element_type(0) != fake_convert->get_input_element_type(0)) { + weights = std::make_shared(weights, fake_convert->get_input_element_type(0)); + } + const auto scale = fake_convert->input_value(1); + weights = std::make_shared(weights, scale); + const auto shift = get_fake_convert_shift(fake_convert); + if (shift) { + from.push_back(shift); + weights = std::make_shared(weights, shift); + } + const auto destination_type = element::Type(fake_convert->get_destination_type()); + const auto weights_convert = std::make_shared(weights, destination_type); + auto compressed_weights = ov::util::constantfold_subgraph(weights_convert); + if (!compressed_weights) { + return false; + } + node_registry.add(compressed_weights); + + const auto convert = + node_registry.make(compressed_weights, fake_convert->get_input_element_type(0)); + const auto inv_scale = ov::util::constantfold_subgraph( + std::make_shared(scale, + op::v0::Constant::create(scale.get_element_type(), Shape{}, {-1.0f}))); + if (!inv_scale) + return false; + node_registry.add(inv_scale); + std::shared_ptr multiply; + if (shift) { + // TODO: check if shift can be fused to weights and eliminate it + const auto neg_shift = ov::util::constantfold_subgraph(std::make_shared(shift)); + if (!neg_shift) + return false; + node_registry.add(neg_shift); + const auto subtract = node_registry.make(convert, neg_shift); + multiply = node_registry.make(subtract, inv_scale); + } else { + multiply = node_registry.make(convert, inv_scale); + } + + compressed_weights->set_friendly_name(weights->get_friendly_name()); + multiply->set_friendly_name(fake_convert->get_friendly_name()); + + copy_runtime_info(from, node_registry.get()); + + replace_node(fake_convert, multiply); + + return true; + }; + + auto m = std::make_shared(root, "CompressWeightsWithFakeConvert"); + register_matcher(m, callback); } static ov::Tensor tensor_from_constant(const std::shared_ptr& constant) { @@ -686,7 +773,6 @@ std::shared_ptr compress_quantized_weights( const std::shared_ptr& convert, const std::shared_ptr& zero_point, bool& can_fuse_zero_point) { - std::shared_ptr new_weights; const auto& weights_shape = weights->get_shape(); const auto& type = weights->get_element_type(); const auto& low_precision_type = convert->get_output_element_type(0); @@ -715,7 +801,6 @@ std::shared_ptr compress_quantized_weights( zero_point_constant->get_shape(), fq->get_levels(), can_fuse_zero_point); - break; } case ov::element::f16: { return compress_quantized_weights_internal(low_precision_type, @@ -733,7 +818,6 @@ std::shared_ptr compress_quantized_weights( zero_point_constant->get_shape(), fq->get_levels(), can_fuse_zero_point); - break; } case ov::element::bf16: { return compress_quantized_weights_internal(low_precision_type, @@ -751,7 +835,6 @@ std::shared_ptr compress_quantized_weights( zero_point_constant->get_shape(), fq->get_levels(), can_fuse_zero_point); - break; } default: return nullptr; @@ -832,57 +915,58 @@ std::shared_ptr compress_quantized_weights( bool zero_point_is_zero, const ov::Tensor& zero_point_tensor, bool& can_fuse_zero_point) { - std::shared_ptr new_weights; const auto& weights_shape = weights->get_shape(); const auto& type = weights->get_element_type(); switch (type) { case ov::element::f32: { - new_weights = compress_quantized_weights_internal(low_precision_type, - weights->get_data_ptr(), - weights_shape, - input_low->get_data_ptr(), - input_low->get_shape(), - input_high->get_data_ptr(), - input_low->get_shape(), - zero_point_tensor.data(), - zero_point_tensor.get_shape(), - levels, - zero_point_is_zero, - can_fuse_zero_point); - break; + return compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + zero_point_tensor.data(), + zero_point_tensor.get_shape(), + levels, + zero_point_is_zero, + can_fuse_zero_point); } case ov::element::f16: { - new_weights = compress_quantized_weights_internal(low_precision_type, - weights->get_data_ptr(), - weights_shape, - input_low->get_data_ptr(), - input_low->get_shape(), - input_high->get_data_ptr(), - input_low->get_shape(), - zero_point_tensor.data(), - zero_point_tensor.get_shape(), - levels, - zero_point_is_zero, - can_fuse_zero_point); - break; + return compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + zero_point_tensor.data(), + zero_point_tensor.get_shape(), + levels, + zero_point_is_zero, + can_fuse_zero_point); } case ov::element::bf16: { - new_weights = compress_quantized_weights_internal(low_precision_type, - weights->get_data_ptr(), - weights_shape, - input_low->get_data_ptr(), - input_low->get_shape(), - input_high->get_data_ptr(), - input_low->get_shape(), - zero_point_tensor.data(), - zero_point_tensor.get_shape(), - levels, - zero_point_is_zero, - can_fuse_zero_point); - break; + return compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + zero_point_tensor.data(), + zero_point_tensor.get_shape(), + levels, + zero_point_is_zero, + can_fuse_zero_point); } default: return nullptr; } - return new_weights; + return nullptr; +} + +ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() { + add_matcher(); + add_matcher(); } diff --git a/src/common/transformations/tests/utils/compress_quantize_weights.cpp b/src/common/transformations/tests/utils/compress_quantize_weights.cpp index 4c1eef7cea5489..553dab11476a1a 100644 --- a/src/common/transformations/tests/utils/compress_quantize_weights.cpp +++ b/src/common/transformations/tests/utils/compress_quantize_weights.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2023 Intel Corporation +// Copyright (C) 2018-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -391,3 +391,164 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } + +using CompressWeightsWithFakeConvertParams = std::tuple; // float8 type + +class CompressWeightsNoZeroPoint : public TransformationTestsF, + public testing::WithParamInterface {}; + +TEST_P(CompressWeightsNoZeroPoint, FakeConvert) { + const auto& param = GetParam(); + bool zero_point_absent = std::get<0>(param); + std::string destination_type = std::get<1>(param); + + { + auto weights = op::v0::Constant::create(element::f32, + Shape{3, 1, 2, 2}, + {-0.01448f, + -0.02314f, + -0.02244f, + -0.00090f, + 0.024261f, + 0.031921f, + 0.034088f, + -0.0497f, + -0.0588f, + -0.04541f, + -0.01281f, + 0.009109f}); + auto scale = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {54.50976f}); + std::shared_ptr fake_convert; + if (zero_point_absent) { + fake_convert = std::make_shared(weights, scale, destination_type); + } else { + auto shift = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {0.0f}); + fake_convert = std::make_shared(weights, scale, shift, destination_type); + } + model = std::make_shared(fake_convert, ParameterVector{}); + + manager.register_pass(); + } + + { + // TODO: change when it's allowed to create a fp8 constant from fp32 values + std::vector weights_data = destination_type == "f8e4m3" ? std::vector{0xb5, // -0.8125 + 0xba, // -1.25 + 0xba, // -1.25 + 0x95, // -0.0507812 + 0x3b, // 1.375 + 0x3e, // 1.75 + 0x3f, // 1.875 + 0xc3, // -2.75 + 0xc5, // -3.25 + 0xc2, // -2.5 + 0xb3, // -0.6875 + 0x30} // 0.5 + : + + std::vector{0xba, // -0.75 + 0xbd, // -1.25 + 0xbd, // -1.25 + 0xaa, // -0.046875 + 0x3d, // 1.25 + 0x3f, // 1.75 + 0x3f, // 1.75 + 0xc1, // -2.5 + 0xc2, // -3 + 0xc1, // -2.5 + 0xba, // -0.75 + 0x38}; // 0.5 + + auto weights = + std::make_shared(element::Type(destination_type), Shape{3, 1, 2, 2}, weights_data.data()); + auto convert = std::make_shared(weights, element::f32); + auto scale = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {0.01834533}); + auto multiply = std::make_shared(convert, scale); + model_ref = std::make_shared(multiply, ParameterVector{}); + } + + m_abs_threshold = 1e-6f; + m_rel_threshold = 1e-6f; + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); +} + +INSTANTIATE_TEST_SUITE_P(CompressQuantizeWeights, + CompressWeightsNoZeroPoint, + testing::Combine(testing::Values(false, true), testing::Values("f8e4m3", "f8e5m2"))); + +class CompressWeightsWithZeroPoint : public TransformationTestsF, public testing::WithParamInterface {}; + +TEST_P(CompressWeightsWithZeroPoint, FakeConvert) { + const auto& destination_type = GetParam(); + + { + auto weights = op::v0::Constant::create(element::f32, + Shape{3, 1, 2, 2}, + {-0.01448f, + -0.02314f, + -0.02244f, + -0.00090f, + 0.024261f, + 0.031921f, + 0.034088f, + -0.0497f, + -0.0588f, + -0.04541f, + -0.01281f, + 0.009109f}); + auto scale = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {54.50976f}); + auto shift = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {0.7f, -0.0304f, -0.012f}); + auto fake_convert = std::make_shared(weights, scale, shift, destination_type); + model = std::make_shared(fake_convert, ParameterVector{}); + + manager.register_pass(); + } + + { + // TODO: change when it's allowed to create a fp8 constant from fp32 values + std::vector weights_data = destination_type == "f8e4m3" ? std::vector{0xbc, // -1.5 + 0xc0, // -2 + 0xbf, // -1.875 + 0xb4, // -0.75 + 0x3b, // 1.375 + 0x3e, // 1.75 + 0x3f, // 1.875 + 0xc3, // -2.75 + 0xc5, // -3.25 + 0xc2, // -2.5 + 0xb3, // -0.6875 + 0x30} // 0.5 + : std::vector{0xbe, // -1.5 + 0xc0, // -2 + 0xc0, // -2 + 0xba, // -0.75 + 0x3d, // 1.25 + 0x3f, // 1.75 + 0x40, // 2 + 0xc1, // -2.5 + 0xc2, // -3 + 0xc1, // -2.5 + 0xb9, // -0.625 + 0x38}; // 0.5 + + auto weights = + std::make_shared(element::Type(destination_type), Shape{3, 1, 2, 2}, weights_data.data()); + auto convert = std::make_shared(weights, element::f32); + auto shift = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {-0.7f, 0.0304f, 0.012f}); + auto subtract = std::make_shared(convert, shift); + auto scale = op::v0::Constant::create(element::f32, Shape{3, 1, 1, 1}, {1.0f / 54.50976f}); + auto multiply = std::make_shared(subtract, scale); + model_ref = std::make_shared(multiply, ParameterVector{}); + } + + m_abs_threshold = 1e-6f; + m_rel_threshold = 1e-6f; + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); +} + +INSTANTIATE_TEST_SUITE_P(CompressQuantizeWeights, CompressWeightsWithZeroPoint, testing::Values("f8e4m3", "f8e5m2")); diff --git a/src/core/include/openvino/op/constant.hpp b/src/core/include/openvino/op/constant.hpp index ce089b2e6f7819..11c578be92a405 100644 --- a/src/core/include/openvino/op/constant.hpp +++ b/src/core/include/openvino/op/constant.hpp @@ -376,6 +376,12 @@ class OPENVINO_API Constant : public Op { case Type_t::u64: cast_vector(rc, num_elements_to_cast); break; + case Type_t::f8e4m3: + cast_vector(rc, num_elements_to_cast); + break; + case Type_t::f8e5m2: + cast_vector(rc, num_elements_to_cast); + break; case Type_t::string: cast_vector(rc, num_elements_to_cast); break; diff --git a/src/core/src/pass/serialize.cpp b/src/core/src/pass/serialize.cpp index 142fb71b345fdd..28d5389326d1f4 100644 --- a/src/core/src/pass/serialize.cpp +++ b/src/core/src/pass/serialize.cpp @@ -754,6 +754,10 @@ std::string get_precision_name(const ov::element::Type& elem_type) { return "BOOL"; case ::ov::element::Type_t::nf4: return "NF4"; + case ::ov::element::Type_t::f8e4m3: + return "F8E4M3"; + case ::ov::element::Type_t::f8e5m2: + return "F8E5M2"; case ::ov::element::Type_t::string: return "STRING"; default: