From b3c2c386dc527f5b877680fbe1215b6a3e05968f Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Tue, 16 Jan 2024 15:26:43 +0100 Subject: [PATCH] [FP8] Implementation of FP8 element types (ov::element::f8e4m3 and ov::element::f8e5m2) (#21608) * FP8 element types init * Remove redundant union * Replace using ov * Update fundamental types * Update class name check * Update tests * Remove redundant sign * Expose f8 types in Python * Add to py to dtype map * Style alignment * Align python style * Update test values * Remove constexpr from_bits to fix warning * Add trivially construcitble asserts and common constrexpr * Align python tests opset * Update f8e4m3 <-> float conversion * f8e5m2 class update * Add f8e5m2 unit test * Add to string conversion tests * Rename class f8e4m3 -> float8_e4m3 * Rename f8e5m2 -> float8_e5m2 * Remove size() and to_string from float8 - size() can be replaced by compile time sizeof - to_string can be replaced by std::to_string() * float8 E5M2 remove unused constexpr value * Fix union initialization and ncc style rules * Fix test issues * Use NaN from std::numeric_limits instead macro - minor refactor of float8_e4m3 * Update nf4 usage in element_type.cpp * Sync openvino.style with master * Update f8e5m2 test --------- Co-authored-by: Raasz, Pawel --- .../ncc_naming_style/openvino.style | 4 +- src/bindings/c/include/openvino/c/ov_common.h | 4 +- src/bindings/c/src/ov_tensor.cpp | 4 +- .../python/src/pyopenvino/core/common.cpp | 28 +-- .../pyopenvino/graph/types/element_type.cpp | 2 + .../python/tests/test_graph/test_constant.py | 122 ++++++++++++ src/core/include/ngraph/type/element_type.hpp | 2 + .../openvino/core/type/element_type.hpp | 14 ++ .../core/type/element_type_traits.hpp | 10 + .../openvino/core/type/float8_e4m3.hpp | 157 ++++++++++++++++ .../openvino/core/type/float8_e5m2.hpp | 157 ++++++++++++++++ src/core/include/openvino/op/constant.hpp | 12 ++ .../openvino/reference/utils/type_util.hpp | 3 +- src/core/src/op/constant.cpp | 
22 ++- src/core/src/op/convert.cpp | 5 +- src/core/src/pass/visualize_tree.cpp | 2 + src/core/src/type/element_type.cpp | 79 ++++---- src/core/src/type/float8_e4m3.cpp | 135 ++++++++++++++ src/core/src/type/float8_e5m2.cpp | 47 +++++ src/core/tests/element_type.cpp | 4 + src/core/tests/eval.cpp | 112 +++++++++-- src/core/tests/float8_e4m3.cpp | 175 +++++++++++++++++ src/core/tests/float8_e5m2.cpp | 176 ++++++++++++++++++ .../op_reference/base_reference_test.cpp | 16 ++ .../functional/op_reference/constant.cpp | 44 +++++ .../tests/functional/op_reference/convert.cpp | 157 ++++++++++++++++ .../include/common_test_utils/data_utils.hpp | 8 + 27 files changed, 1426 insertions(+), 75 deletions(-) create mode 100644 src/core/include/openvino/core/type/float8_e4m3.hpp create mode 100644 src/core/include/openvino/core/type/float8_e5m2.hpp create mode 100644 src/core/src/type/float8_e4m3.cpp create mode 100644 src/core/src/type/float8_e5m2.cpp create mode 100644 src/core/tests/float8_e4m3.cpp create mode 100644 src/core/tests/float8_e5m2.cpp diff --git a/cmake/developer_package/ncc_naming_style/openvino.style b/cmake/developer_package/ncc_naming_style/openvino.style index ebf9ef078d4ba7..6608795381e4a1 100644 --- a/cmake/developer_package/ncc_naming_style/openvino.style +++ b/cmake/developer_package/ncc_naming_style/openvino.style @@ -1,6 +1,6 @@ # custom OpenVINO values CppMethod: '^(operator\W+|[a-z_\d]+|signaling_NaN|quiet_NaN|OPENVINO_OP)$' -ClassName: '^([A-Z][\w]+|b?float16|numeric_limits|ngraph_error|stopwatch|unsupported_op)$' +ClassName: '^([A-Z][\w]+|b?float16|float8_e4m3|float8_e5m2|numeric_limits|ngraph_error|stopwatch|unsupported_op)$' StructName: '^([A-Z][\w]+|element_type_traits|hash|oi_pair|stat)$' FunctionName: '^(operator\W+|[a-z_\d]+)|PrintTo$' Namespace: '^([a-z\d_]*|InferenceEngine)$' @@ -18,7 +18,7 @@ VariableReference: '^\w+$' EnumName: '^[A-Z][\w]+$' # excepts element_type -EnumConstantName: 
'^([A-Z\d_]+|undefined|dynamic|boolean|bf16|f16|f32|f64|i4|i8|i16|i32|i64|u1|u4|u8|u16|u32|u64|nf4|string|asymmetric|align_corners|round_prefer_floor|round_prefer_ceil|floor|ceil|simple|nearest|linear|linear_onnx|cubic|area|scales|sizes|half_pixel|tf_half_pixel_for_nn|pytorch_half_pixel|asymetric)$' +EnumConstantName: '^([A-Z\d_]+|undefined|dynamic|boolean|bf16|f16|f32|f64|i4|i8|i16|i32|i64|u1|u4|u8|u16|u32|u64|nf4|f8e4m3|f8e5m2|string|asymmetric|align_corners|round_prefer_floor|round_prefer_ceil|floor|ceil|simple|nearest|linear|linear_onnx|cubic|area|scales|sizes|half_pixel|tf_half_pixel_for_nn|pytorch_half_pixel|asymetric)$' # TODO: align UsingDeclaration: '^.*$' TypedefName: '^.*$' diff --git a/src/bindings/c/include/openvino/c/ov_common.h b/src/bindings/c/include/openvino/c/ov_common.h index bbbf3dd35c2db1..f1be5750c3bfec 100644 --- a/src/bindings/c/include/openvino/c/ov_common.h +++ b/src/bindings/c/include/openvino/c/ov_common.h @@ -187,6 +187,8 @@ typedef enum { U32, //!< u32 element type U64, //!< u64 element type NF4, //!< nf4 element type + F8E4M3, //!< f8e4m3 element type + F8E5M2, //!< f8e5m2 element type } ov_element_type_e; /** * @@ -210,4 +212,4 @@ ov_free(const char* content); * @ingroup ov_base_c_api */ OPENVINO_C_API(const char*) -ov_get_last_err_msg(); \ No newline at end of file +ov_get_last_err_msg(); diff --git a/src/bindings/c/src/ov_tensor.cpp b/src/bindings/c/src/ov_tensor.cpp index a3372583637ff5..d81ab949d6bffb 100644 --- a/src/bindings/c/src/ov_tensor.cpp +++ b/src/bindings/c/src/ov_tensor.cpp @@ -24,7 +24,9 @@ const std::map element_type_map = { {ov_element_type_e::U16, ov::element::u16}, {ov_element_type_e::U32, ov::element::u32}, {ov_element_type_e::U64, ov::element::u64}, - {ov_element_type_e::NF4, ov::element::nf4}}; + {ov_element_type_e::NF4, ov::element::nf4}, + {ov_element_type_e::F8E4M3, ov::element::f8e4m3}, + {ov_element_type_e::F8E5M2, ov::element::f8e5m2}}; inline ov_element_type_e find_ov_element_type_e(ov::element::Type
type) { for (auto iter = element_type_map.begin(); iter != element_type_map.end(); iter++) { diff --git a/src/bindings/python/src/pyopenvino/core/common.cpp b/src/bindings/python/src/pyopenvino/core/common.cpp index a8603423ecdf0c..945a0c5f777e89 100644 --- a/src/bindings/python/src/pyopenvino/core/common.cpp +++ b/src/bindings/python/src/pyopenvino/core/common.cpp @@ -19,24 +19,16 @@ namespace type_helpers { const std::map& ov_type_to_dtype() { static const std::map ov_type_to_dtype_mapping = { - {ov::element::f16, py::dtype("float16")}, - {ov::element::bf16, py::dtype("float16")}, - {ov::element::f32, py::dtype("float32")}, - {ov::element::f64, py::dtype("float64")}, - {ov::element::i8, py::dtype("int8")}, - {ov::element::i16, py::dtype("int16")}, - {ov::element::i32, py::dtype("int32")}, - {ov::element::i64, py::dtype("int64")}, - {ov::element::u8, py::dtype("uint8")}, - {ov::element::u16, py::dtype("uint16")}, - {ov::element::u32, py::dtype("uint32")}, - {ov::element::u64, py::dtype("uint64")}, - {ov::element::boolean, py::dtype("bool")}, - {ov::element::u1, py::dtype("uint8")}, - {ov::element::u4, py::dtype("uint8")}, - {ov::element::nf4, py::dtype("uint8")}, - {ov::element::i4, py::dtype("int8")}, - {ov::element::string, py::dtype("bytes_")}, + {ov::element::f16, py::dtype("float16")}, {ov::element::bf16, py::dtype("float16")}, + {ov::element::f32, py::dtype("float32")}, {ov::element::f64, py::dtype("float64")}, + {ov::element::i8, py::dtype("int8")}, {ov::element::i16, py::dtype("int16")}, + {ov::element::i32, py::dtype("int32")}, {ov::element::i64, py::dtype("int64")}, + {ov::element::u8, py::dtype("uint8")}, {ov::element::u16, py::dtype("uint16")}, + {ov::element::u32, py::dtype("uint32")}, {ov::element::u64, py::dtype("uint64")}, + {ov::element::boolean, py::dtype("bool")}, {ov::element::u1, py::dtype("uint8")}, + {ov::element::u4, py::dtype("uint8")}, {ov::element::nf4, py::dtype("uint8")}, + {ov::element::i4, py::dtype("int8")}, {ov::element::f8e4m3, 
py::dtype("uint8")}, + {ov::element::f8e5m2, py::dtype("uint8")}, {ov::element::string, py::dtype("bytes_")}, }; return ov_type_to_dtype_mapping; } diff --git a/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp b/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp index 595d06fb073145..dc7c484012c53c 100644 --- a/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp +++ b/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp @@ -50,6 +50,8 @@ void regclass_graph_Type(py::module m) { type.attr("u64") = ov::element::u64; type.attr("bf16") = ov::element::bf16; type.attr("nf4") = ov::element::nf4; + type.attr("f8e4m3") = ov::element::f8e4m3; + type.attr("f8e5m2") = ov::element::f8e5m2; type.attr("string") = ov::element::string; type.def("__hash__", &ov::element::Type::hash); diff --git a/src/bindings/python/tests/test_graph/test_constant.py b/src/bindings/python/tests/test_graph/test_constant.py index 73a288e943b801..6452084e377f09 100644 --- a/src/bindings/python/tests/test_graph/test_constant.py +++ b/src/bindings/python/tests/test_graph/test_constant.py @@ -411,3 +411,125 @@ def test_memory_sharing(shared_flag): else: assert not np.array_equal(ov_const.data, arr) assert not np.shares_memory(arr, ov_const.data) + + +@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [ + (Type.f32, np.float32), + (Type.f16, np.float16), +]) +def test_float_to_f8e5m2_constant(ov_type, numpy_dtype): + from openvino.runtime import opset12 as opset + import openvino as ov + data = np.array([4.75, 4.5, -5.25, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1, -0.0, -0.1, -0.2, -0.3, + -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, -1.0, 0.0000152587890625, 448, 500, 512, 57344], dtype=numpy_dtype) + + compressed_const = opset.constant(data, dtype=ov.Type.f8e5m2, name="f8e5m2_constant") + convert = opset.convert(compressed_const, data.dtype) + parameter = opset.parameter(ov.PartialShape([-1]), ov_type) + add_op = opset.add(parameter, convert) + model 
= ov.Model([add_op], [parameter]) + + compiled = ov.compile_model(model) + tensor = np.zeros(data.shape, dtype=numpy_dtype) + result = compiled(tensor)[0] + + target = [5.0, 4.0, -5.0, 0.0, 0.09375, 0.1875, 0.3125, 0.375, 0.5, 0.625, 0.75, + 0.75, 0.875, 1.0, -0.0, -0.09375, -0.1875, -0.3125, -0.375, + -0.5, -0.625, -0.75, -0.75, -0.875, -1.0, 0.0000152587890625, + 448, 512, 512, 57344] + target = np.array(target, dtype=numpy_dtype) + + assert np.allclose(result, target) + + +@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [ + (Type.f32, np.float32), + (Type.f16, np.float16), +]) +def test_float_to_f8e4m3_constant(ov_type, numpy_dtype): + from openvino.runtime import opset12 as opset + import openvino as ov + data = np.array([4.75, 4.5, -5.25, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1, -0.0, -0.1, -0.2, -0.3, + -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, -1, 448, 512], dtype=numpy_dtype) + + compressed_const = opset.constant(data, dtype=ov.Type.f8e4m3, name="f8e4m3_constant") + convert = opset.convert(compressed_const, data.dtype) + parameter = opset.parameter(ov.PartialShape([-1]), ov_type) + add_op = opset.add(parameter, convert) + model = ov.Model([add_op], [parameter]) + + compiled = ov.compile_model(model) + tensor = np.zeros(data.shape, dtype=numpy_dtype) + result = compiled(tensor)[0] + + target = [5.0, 4.5, -5.0, 0.0, 0.1015625, 0.203125, 0.3125, + 0.40625, 0.5, 0.625, 0.6875, 0.8125, 0.875, 1, + -0, -0.1015625, -0.203125, -0.3125, -0.40625, -0.5, -0.625, + -0.6875, -0.8125, -0.875, -1, 448, np.nan] + target = np.array(target, dtype=numpy_dtype) + + assert np.allclose(result, target, equal_nan=True) + + +@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [ + (Type.f32, np.float32), + (Type.f16, np.float16), +]) +def test_float_to_f8e5m2_convert(ov_type, numpy_dtype): + from openvino.runtime import opset12 as opset + import openvino as ov + data = np.array([4.75, 4.5, -5.25, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1, -0.0, -0.1, -0.2, 
-0.3, + -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, -1.0, 0.0000152587890625, 448, 500, 512, 57344], dtype=numpy_dtype) + + compressed_const = opset.constant(data, dtype=ov_type, name="fx_constant") + convert_to_fp8 = opset.convert(compressed_const, Type.f8e5m2) + convert_back = opset.convert(convert_to_fp8, ov_type) + parameter = opset.parameter(ov.PartialShape([-1]), ov_type) + add_op = opset.add(parameter, convert_back) + model = ov.Model([add_op], [parameter]) + + compiled = ov.compile_model(model) + tensor = np.zeros(data.shape, dtype=numpy_dtype) + result = compiled(tensor)[0] + + target = [5.0, 4.0, -5.0, 0.0, 0.09375, 0.1875, 0.3125, 0.375, 0.5, 0.625, 0.75, + 0.75, 0.875, 1.0, -0.0, -0.09375, -0.1875, -0.3125, -0.375, + -0.5, -0.625, -0.75, -0.75, -0.875, -1.0, 0.0000152587890625, + 448, 512, 512, 57344] + target = np.array(target, dtype=numpy_dtype) + + assert np.allclose(result, target) + + +@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [ + (Type.f32, np.float32), + (Type.f16, np.float16), +]) +def test_float_to_f8e4m3_convert(ov_type, numpy_dtype): + from openvino.runtime import opset12 as opset + import openvino as ov + data = np.array([4.75, 4.5, -5.25, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1, -0.0, -0.1, -0.2, -0.3, + -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, -1, 448, 512], dtype=numpy_dtype) + + compressed_const = opset.constant(data, dtype=ov_type, name="fx_constant") + convert_to_fp8 = opset.convert(compressed_const, Type.f8e4m3) + convert_back = opset.convert(convert_to_fp8, ov_type) + parameter = opset.parameter(ov.PartialShape([-1]), ov_type) + add_op = opset.add(parameter, convert_back) + model = ov.Model([add_op], [parameter]) + + compiled = ov.compile_model(model) + tensor = np.zeros(data.shape, dtype=numpy_dtype) + result = compiled(tensor)[0] + + target = [5.0, 4.5, -5.0, 0.0, 0.1015625, 0.203125, 0.3125, + 0.40625, 0.5, 0.625, 0.6875, 0.8125, 0.875, 1, + -0, -0.1015625, -0.203125, -0.3125, -0.40625, -0.5, -0.625, + -0.6875, 
-0.8125, -0.875, -1, 448, np.nan] + target = np.array(target, dtype=numpy_dtype) + + assert np.allclose(result, target, equal_nan=True) diff --git a/src/core/include/ngraph/type/element_type.hpp b/src/core/include/ngraph/type/element_type.hpp index cd125409db5bc6..3ff94063d82d65 100644 --- a/src/core/include/ngraph/type/element_type.hpp +++ b/src/core/include/ngraph/type/element_type.hpp @@ -35,6 +35,8 @@ using ov::element::dynamic; using ov::element::f16; using ov::element::f32; using ov::element::f64; +using ov::element::f8e4m3; +using ov::element::f8e5m2; using ov::element::i16; using ov::element::i32; using ov::element::i4; diff --git a/src/core/include/openvino/core/type/element_type.hpp b/src/core/include/openvino/core/type/element_type.hpp index 88e79a75d25174..bea57a6ce98479 100644 --- a/src/core/include/openvino/core/type/element_type.hpp +++ b/src/core/include/openvino/core/type/element_type.hpp @@ -20,6 +20,8 @@ #include "openvino/core/rtti.hpp" #include "openvino/core/type/bfloat16.hpp" #include "openvino/core/type/float16.hpp" +#include "openvino/core/type/float8_e4m3.hpp" +#include "openvino/core/type/float8_e5m2.hpp" /** * @defgroup ov_element_cpp_api Element types * @@ -52,6 +54,8 @@ enum class Type_t { u32, //!< u32 element type u64, //!< u64 element type nf4, //!< nf4 element type + f8e4m3, //!< f8e4m3 element type + f8e5m2, //!< f8e5m2 element type string //!< string element type }; @@ -182,6 +186,12 @@ constexpr Type u64(Type_t::u64); /// \brief nf4 element type /// \ingroup ov_element_cpp_api constexpr Type nf4(Type_t::nf4); +/// \brief f8e4m3 element type +/// \ingroup ov_element_cpp_api +constexpr Type f8e4m3(Type_t::f8e4m3); +/// \brief f8e5m2 element type +/// \ingroup ov_element_cpp_api +constexpr Type f8e5m2(Type_t::f8e5m2); /// \brief string element type /// \ingroup ov_element_cpp_api constexpr Type string(Type_t::string); @@ -219,6 +229,10 @@ OPENVINO_API Type from(); template <> OPENVINO_API Type from(); template <> +OPENVINO_API Type
from(); +template <> +OPENVINO_API Type from(); +template <> OPENVINO_API Type from(); OPENVINO_API Type fundamental_type_for(const Type& type); diff --git a/src/core/include/openvino/core/type/element_type_traits.hpp b/src/core/include/openvino/core/type/element_type_traits.hpp index 33f0bbd059a99d..fefbac51866417 100644 --- a/src/core/include/openvino/core/type/element_type_traits.hpp +++ b/src/core/include/openvino/core/type/element_type_traits.hpp @@ -98,6 +98,16 @@ struct element_type_traits { using value_type = int8_t; }; +template <> +struct element_type_traits { + using value_type = ov::float8_e4m3; +}; + +template <> +struct element_type_traits { + using value_type = ov::float8_e5m2; +}; + template <> struct element_type_traits { using value_type = std::string; diff --git a/src/core/include/openvino/core/type/float8_e4m3.hpp b/src/core/include/openvino/core/type/float8_e4m3.hpp new file mode 100644 index 00000000000000..af95d183d69129 --- /dev/null +++ b/src/core/include/openvino/core/type/float8_e4m3.hpp @@ -0,0 +1,157 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvino/core/core_visibility.hpp" + +namespace ov { + +/** + * @brief Class to represent the f8e4m3 type. 
+ */ +class OPENVINO_API float8_e4m3 { +public: + float8_e4m3() = default; + float8_e4m3(uint32_t sign, uint32_t biased_exponent, uint32_t fraction); + float8_e4m3(float value); + + template + explicit float8_e4m3(I value) : m_value{float8_e4m3{static_cast(value)}.m_value} {} + + template + bool operator==(const T& other) const; + template + bool operator!=(const T& other) const { + return !(*this == other); + } + + template + bool operator<(const T& other) const; + template + bool operator<=(const T& other) const; + template + bool operator>(const T& other) const; + template + bool operator>=(const T& other) const; + template + float8_e4m3 operator+(const T& other) const; + template + float8_e4m3 operator+=(const T& other); + template + float8_e4m3 operator-(const T& other) const; + template + float8_e4m3 operator-=(const T& other); + template + float8_e4m3 operator*(const T& other) const; + template + float8_e4m3 operator*=(const T& other); + template + float8_e4m3 operator/(const T& other) const; + template + float8_e4m3 operator/=(const T& other); + + operator float() const; + + static constexpr float8_e4m3 from_bits(uint8_t bits) { + return float8_e4m3(bits, true); + } + uint8_t to_bits() const; + friend std::ostream& operator<<(std::ostream& out, const float8_e4m3& obj) { + out << static_cast(obj); + return out; + } + +private: + constexpr float8_e4m3(uint8_t x, bool) : m_value{x} {} + + uint8_t m_value; +}; + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4756) +#endif +template +bool float8_e4m3::operator==(const T& other) const { +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + return (static_cast(*this) == static_cast(other)); +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +} + +template +bool float8_e4m3::operator<(const T& other) const { + return (static_cast(*this) < static_cast(other)); +} + +template +bool float8_e4m3::operator<=(const T& other) 
const { + return (static_cast(*this) <= static_cast(other)); +} + +template +bool float8_e4m3::operator>(const T& other) const { + return (static_cast(*this) > static_cast(other)); +} + +template +bool float8_e4m3::operator>=(const T& other) const { + return (static_cast(*this) >= static_cast(other)); +} + +template +float8_e4m3 float8_e4m3::operator+(const T& other) const { + return {static_cast(*this) + static_cast(other)}; +} + +template +float8_e4m3 float8_e4m3::operator+=(const T& other) { + return *this = *this + other; +} + +template +float8_e4m3 float8_e4m3::operator-(const T& other) const { + return {static_cast(*this) - static_cast(other)}; +} + +template +float8_e4m3 float8_e4m3::operator-=(const T& other) { + return *this = *this - other; +} + +template +float8_e4m3 float8_e4m3::operator*(const T& other) const { + return {static_cast(*this) * static_cast(other)}; +} + +template +float8_e4m3 float8_e4m3::operator*=(const T& other) { + return *this = *this * other; +} + +template +float8_e4m3 float8_e4m3::operator/(const T& other) const { + return {static_cast(*this) / static_cast(other)}; +} + +template +float8_e4m3 float8_e4m3::operator/=(const T& other) { + return *this = *this / other; +} +#if defined(_MSC_VER) +# pragma warning(pop) +#endif +} // namespace ov diff --git a/src/core/include/openvino/core/type/float8_e5m2.hpp b/src/core/include/openvino/core/type/float8_e5m2.hpp new file mode 100644 index 00000000000000..e3990de0c56169 --- /dev/null +++ b/src/core/include/openvino/core/type/float8_e5m2.hpp @@ -0,0 +1,157 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvino/core/core_visibility.hpp" + +namespace ov { + +/** + * @brief Class to represent the f8e5m2 type. 
+ */ +class OPENVINO_API float8_e5m2 { +public: + float8_e5m2() = default; + float8_e5m2(uint32_t sign, uint32_t biased_exponent, uint32_t fraction); + float8_e5m2(float value); + + template + explicit float8_e5m2(I value) : m_value{float8_e5m2{static_cast(value)}.m_value} {} + + template + bool operator==(const T& other) const; + template + bool operator!=(const T& other) const { + return !(*this == other); + } + + template + bool operator<(const T& other) const; + template + bool operator<=(const T& other) const; + template + bool operator>(const T& other) const; + template + bool operator>=(const T& other) const; + template + float8_e5m2 operator+(const T& other) const; + template + float8_e5m2 operator+=(const T& other); + template + float8_e5m2 operator-(const T& other) const; + template + float8_e5m2 operator-=(const T& other); + template + float8_e5m2 operator*(const T& other) const; + template + float8_e5m2 operator*=(const T& other); + template + float8_e5m2 operator/(const T& other) const; + template + float8_e5m2 operator/=(const T& other); + + operator float() const; + + static constexpr float8_e5m2 from_bits(uint8_t bits) { + return float8_e5m2(bits, true); + } + uint8_t to_bits() const; + friend std::ostream& operator<<(std::ostream& out, const float8_e5m2& obj) { + out << static_cast(obj); + return out; + } + +private: + constexpr float8_e5m2(uint8_t x, bool) : m_value{x} {} + + uint8_t m_value; +}; + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4756) +#endif +template +bool float8_e5m2::operator==(const T& other) const { +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + return (static_cast(*this) == static_cast(other)); +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +} + +template +bool float8_e5m2::operator<(const T& other) const { + return (static_cast(*this) < static_cast(other)); +} + +template +bool float8_e5m2::operator<=(const T& other) 
const { + return (static_cast(*this) <= static_cast(other)); +} + +template +bool float8_e5m2::operator>(const T& other) const { + return (static_cast(*this) > static_cast(other)); +} + +template +bool float8_e5m2::operator>=(const T& other) const { + return (static_cast(*this) >= static_cast(other)); +} + +template +float8_e5m2 float8_e5m2::operator+(const T& other) const { + return {static_cast(*this) + static_cast(other)}; +} + +template +float8_e5m2 float8_e5m2::operator+=(const T& other) { + return *this = *this + other; +} + +template +float8_e5m2 float8_e5m2::operator-(const T& other) const { + return {static_cast(*this) - static_cast(other)}; +} + +template +float8_e5m2 float8_e5m2::operator-=(const T& other) { + return *this = *this - other; +} + +template +float8_e5m2 float8_e5m2::operator*(const T& other) const { + return {static_cast(*this) * static_cast(other)}; +} + +template +float8_e5m2 float8_e5m2::operator*=(const T& other) { + return *this = *this * other; +} + +template +float8_e5m2 float8_e5m2::operator/(const T& other) const { + return {static_cast(*this) / static_cast(other)}; +} + +template +float8_e5m2 float8_e5m2::operator/=(const T& other) { + return *this = *this / other; +} +#if defined(_MSC_VER) +# pragma warning(pop) +#endif +} // namespace ov diff --git a/src/core/include/openvino/op/constant.hpp b/src/core/include/openvino/op/constant.hpp index fe91da44baf6f4..ce089b2e6f7819 100644 --- a/src/core/include/openvino/op/constant.hpp +++ b/src/core/include/openvino/op/constant.hpp @@ -164,6 +164,12 @@ class OPENVINO_API Constant : public Op { case Type_t::nf4: fill_data(value); break; + case Type_t::f8e4m3: + fill_data(value); + break; + case Type_t::f8e5m2: + fill_data(value); + break; case Type_t::string: fill_data(value); break; @@ -882,6 +888,12 @@ class OPENVINO_API Constant : public Op { case Type_t::nf4: write_buffer(source); break; + case Type_t::f8e4m3: + write_buffer(source); + break; + case Type_t::f8e5m2: + 
write_buffer(source); + break; case Type_t::string: write_buffer(source); break; diff --git a/src/core/reference/include/openvino/reference/utils/type_util.hpp b/src/core/reference/include/openvino/reference/utils/type_util.hpp index 12291761612340..10d513d2fb8b3a 100644 --- a/src/core/reference/include/openvino/reference/utils/type_util.hpp +++ b/src/core/reference/include/openvino/reference/utils/type_util.hpp @@ -19,6 +19,7 @@ namespace ov { template constexpr bool is_floating_point() { using U = typename std::decay::type; - return std::is_floating_point::value || std::is_same::value || std::is_same::value; + return std::is_floating_point::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; } } // namespace ov diff --git a/src/core/src/op/constant.cpp b/src/core/src/op/constant.cpp index 914324a5dc97c6..557c71d014f74d 100644 --- a/src/core/src/op/constant.cpp +++ b/src/core/src/op/constant.cpp @@ -221,8 +221,26 @@ struct ValueToString : ov::element::NotSupported { std::string Constant::convert_value_to_string(size_t index) const { using namespace ov::element; - return IfTypeOf::apply< - ValueToString>(get_element_type(), this, index); + return IfTypeOf::apply(get_element_type(), this, index); } size_t Constant::get_byte_size() const { diff --git a/src/core/src/op/convert.cpp b/src/core/src/op/convert.cpp index b48f2d5e433ba7..a79922dcacd42f 100644 --- a/src/core/src/op/convert.cpp +++ b/src/core/src/op/convert.cpp @@ -19,7 +19,8 @@ constexpr bool is_lp_type(const element::Type_t et) { return (et == element::i4) || (et == element::u1) || (et == element::u4) || (et == element::nf4); } -#define CONVERT_ET_LIST boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u4, u8, u16, u32, u64, nf4 +#define CONVERT_ET_LIST \ + boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u4, u8, u16, u32, u64, nf4, f8e4m3, f8e5m2 struct Evaluate : public element::NoAction { using element::NoAction::visit; @@ -173,6 +174,8 @@ 
bool Convert::has_evaluate() const { case element::u32: case element::u64: case element::nf4: + case element::f8e4m3: + case element::f8e5m2: return true; default: return false; diff --git a/src/core/src/pass/visualize_tree.cpp b/src/core/src/pass/visualize_tree.cpp index e981f7c4c95911..bcd3bcd1713390 100644 --- a/src/core/src/pass/visualize_tree.cpp +++ b/src/core/src/pass/visualize_tree.cpp @@ -376,6 +376,8 @@ static std::string get_value(const std::shared_ptr& consta case ov::element::Type_t::u4: case ov::element::Type_t::nf4: case ov::element::Type_t::i4: + case ov::element::Type_t::f8e4m3: + case ov::element::Type_t::f8e5m2: ss << constant->get_output_element_type(0).get_type_name() << " value"; break; case ov::element::Type_t::bf16: diff --git a/src/core/src/type/element_type.cpp b/src/core/src/type/element_type.cpp index e5607e6a60a4de..a49312b6530378 100644 --- a/src/core/src/type/element_type.cpp +++ b/src/core/src/type/element_type.cpp @@ -71,6 +71,10 @@ inline TypeInfo get_type_info(ov::element::Type_t type) { return {64, false, false, false, "uint64_t", "u64"}; case ov::element::Type_t::nf4: return {4, false, false, true, "nfloat4", "nf4"}; + case ov::element::Type_t::f8e4m3: + return {8, true, true, true, "f8e4m3", "f8e4m3"}; + case ov::element::Type_t::f8e5m2: + return {8, true, true, true, "f8e5m2", "f8e5m2"}; case ov::element::Type_t::string: return {8 * sizeof(std::string), false, false, false, "string", "string"}; default: @@ -119,6 +123,10 @@ ov::element::Type type_from_string(const std::string& type) { return ::ov::element::Type(::ov::element::Type_t::dynamic); } else if (type == "nf4" || type == "NF4") { return ::ov::element::Type(::ov::element::Type_t::nf4); + } else if (type == "f8e4m3" || type == "F8E4M3") { + return ::ov::element::Type(::ov::element::Type_t::f8e4m3); + } else if (type == "f8e5m2" || type == "F8E5M2") { + return ::ov::element::Type(::ov::element::Type_t::f8e5m2); } else { OPENVINO_THROW("Incorrect type: ", type); } @@ 
-126,24 +134,12 @@ ov::element::Type type_from_string(const std::string& type) { } // namespace std::vector ov::element::Type::get_known_types() { - std::vector rc = {&ov::element::dynamic, - &ov::element::boolean, - &ov::element::bf16, - &ov::element::f16, - &ov::element::f32, - &ov::element::f64, - &ov::element::i4, - &ov::element::i8, - &ov::element::i16, - &ov::element::i32, - &ov::element::i64, - &ov::element::u1, - &ov::element::u4, - &ov::element::u8, - &ov::element::u16, - &ov::element::u32, - &ov::element::u64, - &ov::element::string}; + std::vector rc = { + &ov::element::dynamic, &ov::element::boolean, &ov::element::bf16, &ov::element::f16, &ov::element::f32, + &ov::element::f64, &ov::element::i4, &ov::element::i8, &ov::element::i16, &ov::element::i32, + &ov::element::i64, &ov::element::u1, &ov::element::u4, &ov::element::u8, &ov::element::u16, + &ov::element::u32, &ov::element::u64, &ov::element::nf4, &ov::element::f8e4m3, &ov::element::f8e5m2, + &ov::element::string}; return rc; } @@ -172,7 +168,9 @@ ov::element::Type::Type(size_t bitwidth, {ov::element::Type_t::u16, {16, false, false, false, "uint16_t", "u16"}}, {ov::element::Type_t::u32, {32, false, false, false, "uint32_t", "u32"}}, {ov::element::Type_t::u64, {64, false, false, false, "uint64_t", "u64"}}, - {ov::element::Type_t::u4, {4, false, false, false, "uint4_t", "nf4"}}, + {ov::element::Type_t::nf4, {4, false, false, true, "nfloat4", "nf4"}}, + {ov::element::Type_t::f8e4m3, {8, true, true, true, "f8e4m3", "f8e4m3"}}, + {ov::element::Type_t::f8e5m2, {8, true, true, true, "f8e5m2", "f8e5m2"}}, {ov::element::Type_t::string, {8 * sizeof(std::string), false, false, false, "string", "string"}}, }; for (const auto& t : elements_map) { @@ -266,6 +264,14 @@ Type from() { return Type_t::bf16; } template <> +Type from() { + return Type_t::f8e4m3; +} +template <> +Type from() { + return Type_t::f8e5m2; +} +template <> Type from() { return Type_t::string; } @@ -282,6 +288,10 @@ Type 
fundamental_type_for(const Type& type) { return from::value_type>(); case Type_t::f64: return from::value_type>(); + case Type_t::f8e4m3: + return from::value_type>(); + case Type_t::f8e5m2: + return from::value_type>(); case Type_t::i4: return from::value_type>(); case Type_t::i8: @@ -304,6 +314,8 @@ Type fundamental_type_for(const Type& type) { return from::value_type>(); case Type_t::u64: return from::value_type>(); + case Type_t::nf4: + return from::value_type>(); case Type_t::string: return from::value_type>(); default: @@ -320,24 +332,13 @@ std::ostream& ov::element::operator<<(std::ostream& out, const ov::element::Type std::istream& ov::element::operator>>(std::istream& in, ov::element::Type& obj) { const std::unordered_map legacy = { - {"BOOL", ov::element::boolean}, - {"BF16", ov::element::bf16}, - {"I4", ov::element::i4}, - {"I8", ov::element::i8}, - {"I16", ov::element::i16}, - {"I32", ov::element::i32}, - {"I64", ov::element::i64}, - {"U4", ov::element::u4}, - {"U8", ov::element::u8}, - {"U16", ov::element::u16}, - {"U32", ov::element::u32}, - {"U64", ov::element::u64}, - {"FP32", ov::element::f32}, - {"FP64", ov::element::f64}, - {"FP16", ov::element::f16}, - {"BIN", ov::element::u1}, - {"NF4", ov::element::nf4}, - {"STRING", ov::element::string}, + {"BOOL", ov::element::boolean}, {"BF16", ov::element::bf16}, {"I4", ov::element::i4}, + {"I8", ov::element::i8}, {"I16", ov::element::i16}, {"I32", ov::element::i32}, + {"I64", ov::element::i64}, {"U4", ov::element::u4}, {"U8", ov::element::u8}, + {"U16", ov::element::u16}, {"U32", ov::element::u32}, {"U64", ov::element::u64}, + {"FP32", ov::element::f32}, {"FP64", ov::element::f64}, {"FP16", ov::element::f16}, + {"BIN", ov::element::u1}, {"NF4", ov::element::nf4}, {"F8E4M3", ov::element::f8e4m3}, + {"F8E5M2", ov::element::f8e5m2}, {"STRING", ov::element::string}, }; std::string str; in >> str; @@ -420,6 +421,8 @@ inline size_t compiler_byte_size(ov::element::Type_t et) { ET_CASE(u32); ET_CASE(u64); 
ET_CASE(nf4); + ET_CASE(f8e4m3); + ET_CASE(f8e5m2); ET_CASE(string); #undef ET_CASE case ov::element::Type_t::undefined: @@ -454,6 +457,8 @@ OPENVINO_API EnumNames& EnumNames::get() { {"u32", element::Type_t::u32}, {"u64", element::Type_t::u64}, {"nf4", element::Type_t::nf4}, + {"f8e4m3", element::Type_t::f8e4m3}, + {"f8e5m2", element::Type_t::f8e5m2}, {"string", element::Type_t::string}}); return enum_names; } diff --git a/src/core/src/type/float8_e4m3.cpp b/src/core/src/type/float8_e4m3.cpp new file mode 100644 index 00000000000000..9041b8a0070497 --- /dev/null +++ b/src/core/src/type/float8_e4m3.cpp @@ -0,0 +1,135 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/core/type/float8_e4m3.hpp" + +#include +#include +#include + +namespace ov { + +static_assert(sizeof(float8_e4m3) == 1, "class f8e4m3 must be exactly 1 byte"); +static_assert(std::is_trivially_constructible::value, "should be trivially constructible"); +static_assert(std::is_trivially_copyable::value, "must be trivially copyable"); +static_assert(std::is_trivially_destructible::value, "must be trivially destructible"); + +namespace { +constexpr auto float_nan = std::numeric_limits::quiet_NaN(); +// Lookup table for conversion f8 -> float. The f8 bit value without sign bit (masked 0x7f) is LUT offset. 
+static constexpr std::array f8_to_float_lut{ + 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, + 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, + 0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f, + 0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, + 0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, + 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, + 64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, float_nan}; + +constexpr uint32_t three_bytes_shift = 24; + +constexpr uint8_t f8e4m3_s_mask = 0x80; // f8e4m3 sign bit mask +constexpr uint8_t f8e4m3_e_size = 4; // f8e4m3 exponent bit size +constexpr uint8_t f8e4m3_e_mask = 0x78; // f8e4m3 exponent bit mask +constexpr uint8_t f8e4m3_e_bias = 7; // f8e4m3 exponent bias +constexpr uint8_t f8e4m3_e_max = 0x0f; // f8e4m3 exponent max value +constexpr uint8_t f8e4m3_m_size = 3; // f8e4m3 mantissa bits size +constexpr uint8_t f8e4m3_m_mask = 0x07; // f8e4m3 mantissa bit mask + +union f32_t { + float value; + uint32_t bits; +}; + +uint8_t f32_to_f8e4m3_bits(const float value) { + constexpr uint32_t f32_s_mask = 0x80000000; // f32 sign bit mask + constexpr uint32_t f32_e_mask = 0x7F800000; // f32 exponent bits mask + constexpr uint32_t f32_e_bias = 127; 
// f32 exponent bias + constexpr uint32_t f32_e_size = 8; // f32 exponent bits size + constexpr uint32_t f32_m_mask = 0x007fffff; // f32 mantissa bits mask + constexpr uint32_t f32_m_size = 23; // f32 mantissa bits size + + constexpr uint32_t f8_e_mask = f8e4m3_e_mask << three_bytes_shift; // f8 exponent bits mask (on u32) + constexpr uint32_t f8_m_mask = f8e4m3_m_mask << three_bytes_shift; // f8 mantissa bits mask (on u32) + constexpr uint32_t f8_m_hidden_one_mask = 0x08000000; // f8 mantissa hidden one bits mask (on u32) + + constexpr uint32_t round_half = 0x01ffffff; // value for half to even round for f8 + constexpr uint32_t round_norm = 0x007fffff; // value for normal round for f8 + constexpr uint32_t round_even = 0x00800000; // value for half to even round for f8 + constexpr uint32_t round_odd = 0x01800000; // value for an non-half to even round for f8 + + const auto input = f32_t{value}; + auto f8_bits = static_cast((input.bits & f32_s_mask) >> three_bytes_shift); + + uint32_t f32_e_field = input.bits & f32_e_mask; + + if (f32_e_field == f32_e_mask) { + f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask); + } else if (f32_e_field != 0) { + int32_t f8_biased_exp = (f32_e_field >> f32_m_size) - (f32_e_bias - f8e4m3_e_bias); + uint32_t fractional = (input.bits & f32_m_mask) << (f32_e_size - f8e4m3_e_size); + + // for normalized values round apply rounding change f8 fractional and biased exponent + if ((fractional & round_half) == round_odd || (fractional & round_norm) != 0) { + fractional += round_even; + if (0 != (fractional & f8_e_mask)) { + fractional &= f8_e_mask; + ++f8_biased_exp; + } + } + fractional &= f8_m_mask; + + // set exponent and mantissa on f8 bits + if (f8_biased_exp > f8e4m3_e_max) { + // Use NAN as this type has no infinity + f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask); + } else if (f8_biased_exp > 0) { + f8_bits |= (f8_biased_exp << f8e4m3_m_size) | (fractional >> three_bytes_shift); + } else { + // Restore the hidden 1 in f8 mantissa for subnormal 
calculation + fractional = f8_m_hidden_one_mask | (input.bits & f32_m_mask) << (f32_e_size - f8e4m3_e_size); + // Will any bits be shifted off? + int32_t shift = f8_biased_exp < -(f8e4m3_e_max) ? 0 : (1U << (1 - f8_biased_exp)); + uint32_t sticky = (fractional & (shift - 1)) ? 1 : 0; + + fractional = ((1 + f8_biased_exp) > f8e4m3_e_max) ? 0 : fractional >> (1 - f8_biased_exp); + fractional |= sticky; + // apply rounding + if (((fractional & round_half) == round_odd) || ((fractional & round_norm) != 0)) { + fractional += round_even; + } + + f8_bits |= fractional >> three_bytes_shift; + } + } + + return f8_bits; +} +} // namespace + +float8_e4m3::float8_e4m3(const uint32_t sign, const uint32_t biased_exponent, const uint32_t fraction) + : m_value(((sign & 0x01U) << (f8e4m3_e_size + f8e4m3_m_size)) | + (biased_exponent & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size | (fraction & f8e4m3_m_mask)) {} + +float8_e4m3::float8_e4m3(const float value) : m_value{f32_to_f8e4m3_bits(value)} {} + +float8_e4m3::operator float() const { + auto converted = f32_t{f8_to_float_lut[m_value & (f8e4m3_e_mask | f8e4m3_m_mask)]}; + converted.bits |= (m_value & f8e4m3_s_mask) << three_bytes_shift; + return converted.value; +} + +uint8_t float8_e4m3::to_bits() const { + return m_value; +} +} // namespace ov diff --git a/src/core/src/type/float8_e5m2.cpp b/src/core/src/type/float8_e5m2.cpp new file mode 100644 index 00000000000000..b44a0f75b21948 --- /dev/null +++ b/src/core/src/type/float8_e5m2.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/core/type/float8_e5m2.hpp" + +#include +#include + +#include "openvino/reference/fake_convert.hpp" + +namespace ov { +static_assert(sizeof(float8_e5m2) == 1, "class f8e5m2 must be exactly 1 byte"); +static_assert(std::is_trivially_constructible::value, "should be trivially constructible"); +static_assert(std::is_trivially_copyable::value, "must be trivially 
copyable"); +static_assert(std::is_trivially_destructible::value, "must be trivially destructible"); + +namespace { + +constexpr uint8_t byte_shift = 8; + +constexpr uint8_t f8e5m2_e_size = 5; // f8e5m2 exponent bit size +constexpr uint8_t f8e5m2_e_mask = 0x7c; // f8e5m2 exponent bit mask +constexpr uint8_t f8e5m2_m_size = 2; // f8e5m2 mantissa bits size +constexpr uint8_t f8e5m2_m_mask = 0x03; // f8e5m2 mantissa bit mask + +uint8_t f32_to_f8e5m2_bits(const float value) { + auto f16 = static_cast(value); + reference::func::emulate_f8e5m2_on_fp16(&f16, &f16, 1); + return static_cast((f16.to_bits() >> byte_shift)); +} +} // namespace + +float8_e5m2::float8_e5m2(uint32_t sign, uint32_t biased_exponent, uint32_t fraction) + : m_value((sign & 0x01) << (f8e5m2_e_size + f8e5m2_m_size) | + (biased_exponent & (f8e5m2_e_mask >> f8e5m2_m_size)) << f8e5m2_m_size | (fraction & f8e5m2_m_mask)) {} + +float8_e5m2::float8_e5m2(const float value) : m_value(f32_to_f8e5m2_bits(value)){}; + +float8_e5m2::operator float() const { + return static_cast(float16::from_bits((static_cast(m_value) << byte_shift))); +} + +uint8_t float8_e5m2::to_bits() const { + return m_value; +} +} // namespace ov diff --git a/src/core/tests/element_type.cpp b/src/core/tests/element_type.cpp index abf8b3f8aa7603..d7bd49c2e2e252 100644 --- a/src/core/tests/element_type.cpp +++ b/src/core/tests/element_type.cpp @@ -67,6 +67,10 @@ TEST(element_type, from_string) { EXPECT_EQ(element::Type("U64"), element::u64); EXPECT_EQ(element::Type("nf4"), element::nf4); EXPECT_EQ(element::Type("NF4"), element::nf4); + EXPECT_EQ(element::Type("f8e4m3"), element::f8e4m3); + EXPECT_EQ(element::Type("F8E4M3"), element::f8e4m3); + EXPECT_EQ(element::Type("f8e5m2"), element::f8e5m2); + EXPECT_EQ(element::Type("F8E5M2"), element::f8e5m2); EXPECT_EQ(element::Type("string"), element::string); EXPECT_EQ(element::Type("STRING"), element::string); diff --git a/src/core/tests/eval.cpp b/src/core/tests/eval.cpp index 
224272715b4da1..bc77f0c2d653fa 100644 --- a/src/core/tests/eval.cpp +++ b/src/core/tests/eval.cpp @@ -2963,7 +2963,8 @@ TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_no_scale_no_shift) { using namespace testing; constexpr auto et = element::f32; - std::vector input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f}; + std::vector input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, + -0.0f, -0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f, -0.7f, -0.8f, -0.9f, -1.f}; const auto data_shape = Shape{input_data.size()}; auto data = make_shared(et, data_shape); @@ -2983,10 +2984,10 @@ TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_no_scale_no_shift) { EXPECT_EQ(result.get_shape(), data_shape); EXPECT_THAT( read_vector(result), - Pointwise( - FloatEq(), - std::vector< - float>{0.f, 0.1015625f, 0.203125f, 0.3125f, 0.40625f, 0.5f, 0.625f, 0.6875f, 0.8125f, 0.875f, 1.f})); + Pointwise(FloatEq(), std::vector{0.f, 0.1015625f, 0.203125f, 0.3125f, 0.40625f, 0.5f, + 0.625f, 0.6875f, 0.8125f, 0.875f, 1.f, -0.f, + -0.1015625f, -0.203125f, -0.3125f, -0.40625f, -0.5f, -0.625f, + -0.6875f, -0.8125f, -0.875f, -1.f})); } TEST(eval, evaluate_fake_convert_f32_seq_to_f8e4m3_scale_1) { @@ -3223,7 +3224,8 @@ TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_1) { using namespace testing; constexpr auto et = element::f32; - std::vector input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f}; + std::vector input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, + -0.0f, -0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f, -0.7f, -0.8f, -0.9f, -1.f}; const auto data_shape = Shape{input_data.size()}; @@ -3244,11 +3246,11 @@ TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_1) { EXPECT_EQ(result.get_element_type(), et); EXPECT_EQ(result.get_shape(), data_shape); - EXPECT_THAT( - read_vector(result), - Pointwise( - FloatEq(), - std::vector{0.f, 0.09375f, 0.1875f, 0.3125f, 0.375f, 0.5f, 0.625f, 0.75f, 0.75f, 0.875f, 1.f})); + 
EXPECT_THAT(read_vector(result), + Pointwise(FloatEq(), + std::vector{0.f, 0.09375f, 0.1875f, 0.3125f, 0.375f, 0.5f, 0.625f, 0.75f, + 0.75f, 0.875f, 1.f, -0.f, -0.09375f, -0.1875f, -0.3125f, -0.375f, + -0.5f, -0.625f, -0.75f, -0.75f, -0.875f, -1.f})); } TEST(eval, evaluate_fake_convert_f16_to_f8e5m2_scale_1) { @@ -3707,7 +3709,7 @@ TEST(eval, evaluate_fake_convert_bf16_matching_f8_to_f8e5m2_scale_1) { 4096.f, 5120.f, 6144.f, 7168.f, 8192.f, 10240.f, 12288.f, 14336.f, 16384.f, 20480.f, 24576.f, 28672.f, - 32768.f, 40960.f, 49152.f, 57344.0 + 32768.f, 40960.f, 49152.f, 57344.f }; // clang-format on @@ -3766,6 +3768,92 @@ TEST(eval, evaluate_fake_convert_f32_matching_f8e4m3_to_f8e5m2_scale_1) { EXPECT_THAT(read_vector(result), Pointwise(FloatEq(), output_data)); } +TEST(eval, evaluate_f8e5m2_const_from_f32) { + using namespace testing; + constexpr auto et = element::f8e5m2; + + std::vector input_data{ + 0.017578125f, 0.021484375f, 0.025390625f, 0.029296875f, 0.03515625f, 0.0703125f, 0.140625f, + 0.28125f, 0.5625f, 1.125f, 1.625f, 1.875f, 2.25f, 3.75f, + 4.5f, 9.f, 18.f, 36.f, 72.f, 144.f, 288.f, + }; + /* Rounded to f8e5m2 vals */ + std::vector output_data{0.015625f, 0.0234375f, 0.0234375f, 0.03125f, 0.03125f, 0.0625f, 0.125f, + 0.25f, 0.5f, 1.f, 1.5, 2.f, 2.f, 4.f, + 4.f, 8.f, 16.f, 32.f, 64.f, 128.f, 256.f}; + + const auto data_shape = Shape{input_data.size()}; + + auto op = make_shared(et, data_shape, input_data); + auto model = make_shared(OutputVector{op}, ParameterVector{}); + + auto result = ov::Tensor(); + auto out_vector = ov::TensorVector{result}; + auto in_vector = ov::TensorVector{}; + ASSERT_TRUE(model->evaluate(out_vector, in_vector)); + result = out_vector.at(0); + + EXPECT_EQ(result.get_element_type(), et); + EXPECT_EQ(result.get_shape(), data_shape); + EXPECT_THAT(read_vector(result), Pointwise(FloatEq(), output_data)); +} + +TEST(eval, evaluate_f8e5m2_const_seq_from_f32) { + using namespace testing; + constexpr auto et = element::f8e5m2; + + 
std::vector input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, + -0.0f, -0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f, -0.7f, -0.8f, -0.9f, -1.f}; + + /* Rounded to f8e5m2 vals */ + std::vector output_data{0.f, 0.09375f, 0.1875f, 0.3125f, 0.375f, 0.5f, 0.625f, 0.75f, + 0.75f, 0.875f, 1.f, -0.f, -0.09375f, -0.1875f, -0.3125f, -0.375f, + -0.5f, -0.625f, -0.75f, -0.75f, -0.875f, -1.f}; + + const auto data_shape = Shape{input_data.size()}; + + auto op = make_shared(et, data_shape, input_data); + auto model = make_shared(OutputVector{op}, ParameterVector{}); + + auto result = ov::Tensor(); + auto out_vector = ov::TensorVector{result}; + auto in_vector = ov::TensorVector{}; + ASSERT_TRUE(model->evaluate(out_vector, in_vector)); + result = out_vector.at(0); + + EXPECT_EQ(result.get_element_type(), et); + EXPECT_EQ(result.get_shape(), data_shape); + EXPECT_THAT(read_vector(result), Pointwise(FloatEq(), output_data)); +} + +TEST(eval, evaluate_f8e4m3_const_seq_from_f32) { + using namespace testing; + constexpr auto et = element::f8e4m3; + + std::vector input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, + -0.0f, -0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f, -0.7f, -0.8f, -0.9f, -1.f}; + + /* Rounded to f8e4m3 vals */ + std::vector output_data{ + 0.f, 0.1015625f, 0.203125f, 0.3125f, 0.40625f, 0.5f, 0.625f, 0.6875f, 0.8125f, 0.875f, 1.f, + -0.f, -0.1015625f, -0.203125f, -0.3125f, -0.40625f, -0.5f, -0.625f, -0.6875f, -0.8125f, -0.875f, -1.f}; + + const auto data_shape = Shape{input_data.size()}; + + auto op = make_shared(et, data_shape, input_data); + auto model = make_shared(OutputVector{op}, ParameterVector{}); + + auto result = ov::Tensor(); + auto out_vector = ov::TensorVector{result}; + auto in_vector = ov::TensorVector{}; + ASSERT_TRUE(model->evaluate(out_vector, in_vector)); + result = out_vector.at(0); + + EXPECT_EQ(result.get_element_type(), et); + EXPECT_EQ(result.get_shape(), data_shape); + EXPECT_THAT(read_vector(result), 
Pointwise(FloatEq(), output_data)); +} + TEST(eval, evaluate_fake_convert_f32_seq_to_f8e5m2_scale_shift) { using namespace testing; constexpr auto et = element::f32; diff --git a/src/core/tests/float8_e4m3.cpp b/src/core/tests/float8_e4m3.cpp new file mode 100644 index 00000000000000..6265b530e10105 --- /dev/null +++ b/src/core/tests/float8_e4m3.cpp @@ -0,0 +1,175 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/core/type/float8_e4m3.hpp" + +#include + +#include + +#include "common_test_utils/float_util.hpp" +namespace ov { +namespace test { + +template +std::vector> enumerate(const TContainer& values) { + std::vector> enum_values; + int i = 0; + for (const auto& v : values) { + enum_values.emplace_back(i, v); + ++i; + } + return enum_values; +} + +TEST(F8E4M3Test, f32_inf) { + const auto f8 = ov::float8_e4m3(std::numeric_limits::infinity()); + + EXPECT_EQ(f8.to_bits(), 0x7f); +} + +TEST(F8E4M3Test, f32_minus_inf) { + const auto f8 = ov::float8_e4m3(-std::numeric_limits::infinity()); + // f8 is NaN as there is no infinity + EXPECT_EQ(f8.to_bits(), 0xff); +} + +TEST(F8E4M3Test, f32_nan) { + const auto f8 = ov::float8_e4m3(std::numeric_limits::quiet_NaN()); + + EXPECT_EQ(f8.to_bits(), 0x7f); +} + +TEST(F8E4M3Test, f32_gt_zero_le_f8_half_lowest_subnormal) { + const auto f8 = ov::float8_e4m3(0.0009765625f); + + EXPECT_EQ(f8.to_bits(), 0x00); +} + +TEST(F8E4M3Test, f32_gt_zero_gt_f8_half_lowest_subnormal) { + const auto f8 = ov::float8_e4m3(0.00097656273283064365387f); + + EXPECT_EQ(f8.to_bits(), 0x01); +} + +TEST(F8E4M3Test, f32_normal_fractional_rounding) { + const auto f8 = ov::float8_e4m3(0.129f); + + // Rounded to 0.140625f -> 0x21 + EXPECT_EQ(f8.to_bits(), 0x20); +} + +TEST(F8E4M3Test, f32_normal_negative_fractional_rounding) { + const auto f8 = ov::float8_e4m3(-0.281f); + + // Rounded to -0.28125f -> 0x21 + EXPECT_EQ(f8.to_bits(), 0xa9); +} + +TEST(F8E4M3Test, f32_ge_f8_max_within_round_to_max) { + 
const auto f8 = ov::float8_e4m3(460.0f); + + // Rounded to 448.0f -> 0x7e + EXPECT_EQ(f8.to_bits(), 0x7e); +} + +TEST(F8E4M3Test, f32_ge_f8_max_not_within_round_to_max) { + const auto f8 = ov::float8_e4m3(560.0f); + + // f8 has no such value (NaN) + EXPECT_EQ(f8.to_bits(), 0x7f); +} + +TEST(F8E4M3Test, f32_le_f8_lowest_within_round_to_lowest) { + const auto f8 = ov::float8_e4m3(-460.0f); + + // Rounded to -448.0f -> 0xfe + EXPECT_EQ(f8.to_bits(), 0xfe); +} + +TEST(F8E4M3Test, f32_le_f8_lowest_not_within_round_to_lowest) { + const auto f8 = ov::float8_e4m3(-760.0f); + + // f8 has no such value (NaN) + EXPECT_EQ(f8.to_bits(), 0xff); +} + +TEST(F8E4M3Test, stream_operator) { + std::stringstream s; + s << ov::float8_e4m3(2.5f); + + EXPECT_EQ(s.str(), "2.5"); +} + +TEST(F8E4M3Test, to_string) { + const auto f8 = ov::float8_e4m3::from_bits(0b00111010); + + EXPECT_EQ(std::to_string(f8), "1.250000"); +} +constexpr auto f32_qnan = std::numeric_limits::quiet_NaN(); + +const auto exp_floats = std::vector{ + 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, + 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, + 0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f, + 0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, + 0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, + 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, + 64.0f, 72.0f, 80.0f, 88.0f, 
96.0f, 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, f32_qnan, + -0.0f, -0.001953125f, -0.00390625f, -0.005859375f, -0.0078125f, -0.009765625f, -0.01171875f, -0.013671875f, + -0.015625f, -0.017578125f, -0.01953125f, -0.021484375f, -0.0234375f, -0.025390625f, -0.02734375f, -0.029296875f, + -0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f, + -0.0625f, -0.0703125f, -0.078125f, -0.0859375f, -0.09375f, -0.1015625f, -0.109375f, -0.1171875f, + -0.125f, -0.140625f, -0.15625f, -0.171875f, -0.1875f, -0.203125f, -0.21875f, -0.234375f, + -0.25f, -0.28125f, -0.3125f, -0.34375f, -0.375f, -0.40625f, -0.4375f, -0.46875f, + -0.5f, -0.5625f, -0.625f, -0.6875f, -0.75f, -0.8125f, -0.875f, -0.9375f, + -1.0f, -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f, + -2.0f, -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f, + -4.0f, -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, + -8.0f, -9.0f, -10.0f, -11.0f, -12.0f, -13.0f, -14.0f, -15.0f, + -16.0f, -18.0f, -20.0f, -22.0f, -24.0f, -26.0f, -28.0f, -30.0f, + -32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f, + -64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f, + -128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f, + -256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, -f32_qnan}; + +using f8m4e3_params = std::tuple; +class F8E4M3PTest : public testing::TestWithParam {}; + +INSTANTIATE_TEST_SUITE_P(convert, + F8E4M3PTest, + testing::ValuesIn(enumerate(exp_floats)), + testing::PrintToStringParamName()); + +TEST_P(F8E4M3PTest, f8_bits_to_f32) { + const auto& params = GetParam(); + const auto& exp_value = std::get<1>(params); + const auto f8 = ov::float8_e4m3::from_bits(std::get<0>(params)); + + if (std::isnan(exp_value)) { + EXPECT_TRUE(std::isnan(static_cast(f8))); + } else { + EXPECT_EQ(static_cast(f8), exp_value); + } 
+} + +TEST_P(F8E4M3PTest, f32_to_f8_bits) { + const auto& params = GetParam(); + const auto& exp_value = std::get<0>(params); + const auto& value = std::get<1>(params); + const auto f8 = ov::float8_e4m3(value); + + EXPECT_EQ(f8.to_bits(), exp_value); +} +} // namespace test +} // namespace ov diff --git a/src/core/tests/float8_e5m2.cpp b/src/core/tests/float8_e5m2.cpp new file mode 100644 index 00000000000000..38028497341c17 --- /dev/null +++ b/src/core/tests/float8_e5m2.cpp @@ -0,0 +1,176 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/core/type/float8_e5m2.hpp" + +#include + +#include + +#include "common_test_utils/float_util.hpp" + +namespace ov { +namespace test { + +TEST(F8E5M2Test, stream_operator) { + std::stringstream s; + s << ov::float8_e5m2(2.5f); + + EXPECT_EQ(s.str(), "2.5"); +} + +TEST(F8E5M2Test, to_string) { + const auto f8 = ov::float8_e5m2::from_bits(0b00111010); + + EXPECT_EQ(std::to_string(f8), "0.750000"); +} + +TEST(F8E5M2Test, f32_inf) { + const auto f8 = ov::float8_e5m2(std::numeric_limits::infinity()); + + EXPECT_EQ(f8.to_bits(), 0b01111100); +} + +TEST(F8E5M2Test, f32_minus_inf) { + const auto f8 = ov::float8_e5m2(-std::numeric_limits::infinity()); + + EXPECT_EQ(f8.to_bits(), 0b11111100); +} + +TEST(F8E5M2Test, f32_ge_f8_max_round_to_inf) { + const auto f8 = ov::float8_e5m2(65520.0f); + + EXPECT_EQ(f8.to_bits(), 0b01111100); +} + +TEST(F8E5M2Test, f32_ge_f8_max_round_to_max) { + const auto f8 = ov::float8_e5m2(65519.9f); + + EXPECT_EQ(f8.to_bits(), 0b01111011); +} + +TEST(F8E5M2Test, f32_ge_f8_max_round_to_minus_inf) { + const auto f8 = ov::float8_e5m2(-65520.0f); + + EXPECT_EQ(f8.to_bits(), 0b11111100); +} + +TEST(F8E5M2Test, f32_ge_f8_max_round_to_lowest) { + const auto f8 = ov::float8_e5m2(-65519.9f); + + EXPECT_EQ(f8.to_bits(), 0b11111011); +} + +template +std::vector> enumerate(const TContainer& values) { + std::vector> enum_values; + int i = 0; + for (const auto& v : 
values) { + enum_values.emplace_back(i, v); + ++i; + } + return enum_values; +} + +constexpr auto f32_qnan = std::numeric_limits::quiet_NaN(); +constexpr auto f32_inf = std::numeric_limits::infinity(); + +// clang-format off +const auto exp_floats = std::vector{ + 0.0f, 1.52587890625e-05f, 3.0517578125e-05f, 4.57763671875e-05f, + 6.103515625e-05f, 7.62939453125e-05f, 9.1552734375e-05f, 0.0001068115234375f, + 0.0001220703125f, 0.000152587890625f, 0.00018310546875f, 0.000213623046875f, + 0.000244140625f, 0.00030517578125f, 0.0003662109375f, 0.00042724609375f, + 0.00048828125f, 0.0006103515625f, 0.000732421875f, 0.0008544921875f, + 0.0009765625f, 0.001220703125f, 0.00146484375f, 0.001708984375f, + 0.001953125f, 0.00244140625f, 0.0029296875f, 0.00341796875f, + 0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f, + 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, + 0.015625f, 0.01953125f, 0.0234375f, 0.02734375f, + 0.03125f, 0.0390625f, 0.046875f, 0.0546875f, + 0.0625f, 0.078125f, 0.09375f, 0.109375f, + 0.125f, 0.15625f, 0.1875f, 0.21875f, + 0.25f, 0.3125f, 0.375f, 0.4375f, + 0.5f, 0.625f, 0.75f, 0.875f, + 1.0f, 1.25f, 1.5f, 1.75f, + 2.0f, 2.5f, 3.0f, 3.5f, + 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 10.0f, 12.0f, 14.0f, + 16.0f, 20.0f, 24.0f, 28.0f, + 32.0f, 40.0f, 48.0f, 56.0f, + 64.0f, 80.0f, 96.0f, 112.0f, + 128.0f, 160.0f, 192.0f, 224.0f, + 256.0f, 320.0f, 384.0f, 448.0f, + 512.0f, 640.0f, 768.0f, 896.0f, + 1024.0f, 1280.0f, 1536.0f, 1792.0f, + 2048.0f, 2560.0f, 3072.0f, 3584.0f, + 4096.0f, 5120.0f, 6144.0f, 7168.0f, + 8192.0f, 10240.0f, 12288.0f, 14336.0f, + 16384.0f, 20480.0f, 24576.0f, 28672.0f, + 32768.0f, 40960.0f, 49152.0f, 57344.0f, + f32_inf, f32_qnan, f32_qnan, f32_qnan, + -0.0f, -1.52587890625e-05f, -3.0517578125e-05f, -4.57763671875e-05f, + -6.103515625e-05f, -7.62939453125e-05f, -9.1552734375e-05f, -0.0001068115234375f, + -0.0001220703125f, -0.000152587890625f, -0.00018310546875f, -0.000213623046875f, + -0.000244140625f, -0.00030517578125f, 
-0.0003662109375f, -0.00042724609375f, + -0.00048828125f, -0.0006103515625f, -0.000732421875f, -0.0008544921875f, + -0.0009765625f, -0.001220703125f, -0.00146484375f, -0.001708984375f, + -0.001953125f, -0.00244140625f, -0.0029296875f, -0.00341796875f, + -0.00390625f, -0.0048828125f, -0.005859375f, -0.0068359375f, + -0.0078125f, -0.009765625f, -0.01171875f, -0.013671875f, + -0.015625f, -0.01953125f, -0.0234375f, -0.02734375f, + -0.03125f, -0.0390625f, -0.046875f, -0.0546875f, + -0.0625f, -0.078125f, -0.09375f, -0.109375f, + -0.125f, -0.15625f, -0.1875f, -0.21875f, + -0.25f, -0.3125f, -0.375f, -0.4375f, + -0.5f, -0.625f, -0.75f, -0.875f, + -1.0f, -1.25f, -1.5f, -1.75f, + -2.0f, -2.5f, -3.0f, -3.5f, + -4.0f, -5.0f, -6.0f, -7.0f, + -8.0f, -10.0f, -12.0f, -14.0f, + -16.0f, -20.0f, -24.0f, -28.0f, + -32.0f, -40.0f, -48.0f, -56.0f, + -64.0f, -80.0f, -96.0f, -112.0f, + -128.0f, -160.0f, -192.0f, -224.0f, + -256.0f, -320.0f, -384.0f, -448.0f, + -512.0f, -640.0f, -768.0f, -896.0f, + -1024.0f, -1280.0f, -1536.0f, -1792.0f, + -2048.0f, -2560.0f, -3072.0f, -3584.0f, + -4096.0f, -5120.0f, -6144.0f, -7168.0f, + -8192.0f, -10240.0f, -12288.0f, -14336.0f, + -16384.0f, -20480.0f, -24576.0f, -28672.0f, + -32768.0f, -40960.0f, -49152.0f, -57344.0f, + -f32_inf, -f32_qnan, -f32_qnan, -f32_qnan}; +// clang-format on + +using f8m5e2_params = std::tuple; +class F8E5M2PTest : public testing::TestWithParam {}; + +INSTANTIATE_TEST_SUITE_P(convert, + F8E5M2PTest, + testing::ValuesIn(enumerate(exp_floats)), + testing::PrintToStringParamName()); + +TEST_P(F8E5M2PTest, f8_bits_to_f32) { + const auto& params = GetParam(); + const auto& exp_value = std::get<1>(params); + const auto f8 = ov::float8_e5m2::from_bits(std::get<0>(params)); + + if (std::isnan(exp_value)) { + EXPECT_TRUE(std::isnan(static_cast(f8))); + } else { + EXPECT_EQ(static_cast(f8), exp_value); + } +} + +TEST_P(F8E5M2PTest, f32_to_f8_bits) { + const auto& params = GetParam(); + const auto& value = std::get<1>(params); + const auto& 
exp_value = std::isnan(value) ? (std::signbit(value) ? 0xfe : 0x7e) : std::get<0>(params); + const auto f8 = ov::float8_e5m2(value); + + EXPECT_EQ(f8.to_bits(), exp_value); +} +} // namespace test +} // namespace ov diff --git a/src/plugins/template/tests/functional/op_reference/base_reference_test.cpp b/src/plugins/template/tests/functional/op_reference/base_reference_test.cpp index d5f5a3a4f19b96..ed8621d0351a3e 100644 --- a/src/plugins/template/tests/functional/op_reference/base_reference_test.cpp +++ b/src/plugins/template/tests/functional/op_reference/base_reference_test.cpp @@ -105,6 +105,22 @@ void CommonReferenceTest::ValidateBlobs(const ov::Tensor& refBlob, threshold, abs_threshold); break; + case ov::element::f8e4m3: + LayerTestsUtils::LayerTestsCommon::Compare( + refBlob.data(), + outBlob.data(), + actual_comparision_size, + threshold, + abs_threshold); + break; + case ov::element::f8e5m2: + LayerTestsUtils::LayerTestsCommon::Compare( + refBlob.data(), + outBlob.data(), + actual_comparision_size, + threshold, + abs_threshold); + break; case ov::element::f32: LayerTestsUtils::LayerTestsCommon::Compare(refBlob.data(), outBlob.data(), diff --git a/src/plugins/template/tests/functional/op_reference/constant.cpp b/src/plugins/template/tests/functional/op_reference/constant.cpp index ababb046e9c2b6..8178544a7d940a 100644 --- a/src/plugins/template/tests/functional/op_reference/constant.cpp +++ b/src/plugins/template/tests/functional/op_reference/constant.cpp @@ -201,6 +201,50 @@ std::vector generateConstantDefinedTypeParams() { std::vector{0x4000000000000001, 0x4000000000000002}, std::vector{0x4000000000000001, 0x4000000000000002}, "tensor_constant_int64"), + ConstantParams( + {3, 9}, + element::Type_t::f8e4m3, + element::Type_t::f8e4m3, + std::vector{4.75f, 4.5f, -5.25f, 0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, + 0.6f, 0.7f, 0.8f, 0.9f, 1.f, -0.0f, -0.1f, -0.2f, -0.3f, + -0.4f, -0.5f, -0.6f, -0.7f, -0.8f, -0.9f, -1.f, 0.001953125f, 448.f}, + std::vector{5.0f, 
4.5f, -5.0f, 0.0f, 0.1015625f, 0.203125f, 0.3125f, + 0.40625f, 0.5f, 0.625f, 0.6875f, 0.8125f, 0.875f, 1.f, + -0.f, -0.1015625f, -0.203125f, -0.3125f, -0.40625f, -0.5f, -0.625f, + -0.6875f, -0.8125f, -0.875f, -1.f, 0.001953125f, 448.f}, + "tensor_constant_f8e4m3"), + ConstantParams({3, 9}, + element::Type_t::f8e5m2, + element::Type_t::f8e5m2, + std::vector{4.75f, 4.5f, + -5.25f, 0.0f, + 0.1f, 0.2f, + 0.3f, 0.4f, + 0.5f, 0.6f, + 0.7f, 0.8f, + 0.9f, 1.f, + -0.0f, -0.1f, + -0.2f, -0.3f, + -0.4f, -0.5f, + -0.6f, -0.7f, + -0.8f, -0.9f, + -1.f, 0.0000152587890625f, + 57344.f}, + std::vector{4.75f, 4.5f, + -5.25f, 0.0f, + 0.09375f, 0.1875f, + 0.3125f, 0.375f, + 0.5f, 0.625f, + 0.75f, 0.75f, + 0.875f, 1.f, + -0.f, -0.09375f, + -0.1875f, -0.3125f, + -0.375f, -0.5f, + -0.625f, -0.75f, + -0.75f, -0.875f, + -1.f, 0.0000152587890625f, + 57344.f}, + "tensor_constant_f8e5m2"), }; return constantParams; } diff --git a/src/plugins/template/tests/functional/op_reference/convert.cpp b/src/plugins/template/tests/functional/op_reference/convert.cpp index b6195744c9c6f3..461daa56d80b14 100644 --- a/src/plugins/template/tests/functional/op_reference/convert.cpp +++ b/src/plugins/template/tests/functional/op_reference/convert.cpp @@ -57,6 +57,163 @@ INSTANTIATE_TEST_SUITE_P( std::numeric_limits::infinity(), -std::numeric_limits::infinity()}, std::vector{0, 1, 1, 0, 1, 1, 1, 1, 1}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 3, 7}, + ov::element::f32, + ov::element::f8e5m2, + std::vector{ + 0.017578125f, 0.021484375f, 0.025390625f, 0.029296875f, 0.03515625f, 0.0703125f, 0.140625f, + 0.28125f, 0.5625f, 1.125f, 1.625f, 1.875f, 2.25f, 3.75f, + 4.5f, 9.f, 18.f, 36.f, 72.f, 144.f, 288.f}, + std::vector{0.015625f, 0.0234375f, 0.0234375f, 0.03125f, 0.03125f, 0.0625f, + 0.125f, 0.25f, 0.5f, 1.f, 1.5, 2.f, + 2.f, 4.f, 4.f, 8.f, 16.f, 32.f, + 64.f, 128.f, 256.f}), + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 3, 7}, + ov::element::f8e5m2, + 
ov::element::f32, + std::vector{0.015625f, 0.0234375f, 0.0234375f, 0.03125f, 0.03125f, 0.0625f, + 0.125f, 0.25f, 0.5f, 1.f, 1.5, 2.f, + 2.f, 4.f, 4.f, 8.f, 16.f, 32.f, + 64.f, 128.f, 256.f}, + std::vector{0.015625f, 0.0234375f, 0.0234375f, 0.03125f, 0.03125f, 0.0625f, 0.125f, + 0.25f, 0.5f, 1.f, 1.5, 2.f, 2.f, 4.f, + 4.f, 8.f, 16.f, 32.f, 64.f, 128.f, 256.f}), + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 7}, + ov::element::f16, + ov::element::f8e5m2, + std::vector{0.f, -0.f, 0.5f, 1.5f, 2.5f, 1.5f, 3.5f}, + std::vector{0.f, -0.f, 0.5f, 1.5f, 2.5f, 1.5f, 3.5f}), + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 7}, + ov::element::f8e5m2, + ov::element::f16, + std::vector{0.f, -0.f, 0.5f, 1.5f, 2.5f, 1.5f, 3.5f}, + std::vector{0.f, -0.f, 0.5f, 1.5f, 2.5f, 1.5f, 3.5f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 3, 5}, + ov::element::f16, + ov::element::f8e4m3, + std::vector{0.0f, + 0.1f, + 0.2f, + 0.3f, + 0.4f, + 0.5f, + 0.6f, + 0.7f, + 0.8f, + 0.9f, + 1.f, + 1.5f, + 2.5f, + 1.5f, + 3.5f}, + std::vector{0.f, + 0.1015625f, + 0.203125f, + 0.3125f, + 0.40625f, + 0.5f, + 0.625f, + 0.6875f, + 0.8125f, + 0.875f, + 1.f, + 1.5f, + 2.5f, + 1.5f, + 3.5f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 3}, + + ov::element::f8e4m3, + ov::element::f16, + std::vector{0.5f, 1.5f, 0.f}, + std::vector{0.5f, 1.5f, 0.f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 3}, + + ov::element::f8e4m3, + ov::element::f8e4m3, + std::vector{0.5f, 1.5f, 0.f}, + std::vector{0.5f, 1.5f, 0.f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 2}, + + ov::element::f8e5m2, + ov::element::f8e5m2, + std::vector{0.5f, 1.5f}, + std::vector{0.5f, 1.5f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 2}, + + ov::element::f32, + ov::element::f32, + std::vector{0.5f, 1.5f}, + std::vector{0.5f, 1.5f}), + + ConvertParams(ConversionTypes::CONVERT, + 
ov::PartialShape{1, 1, 2}, + + ov::element::f8e4m3, + ov::element::f32, + std::vector{0.5f, 1.5f}, + std::vector{0.5f, 1.5f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 2}, + + ov::element::f8e4m3, + ov::element::f16, + std::vector{0.5f, 1.5f}, + std::vector{0.5f, 1.5f}), + + ConvertParams(ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 3, 2}, + ov::element::f8e4m3, + ov::element::f32, + std::vector{ + 0.5f, + 1.5f, + 0.5f, + 2.5f, + 1.5f, + 3.5f, + }, + std::vector{0.5f, 1.5f, 0.5f, 2.5f, 1.5f, 3.5f}), + + ConvertParams( + ConversionTypes::CONVERT, + ov::PartialShape{1, 1, 3, 5}, + ov::element::f32, + ov::element::f8e4m3, + std:: + vector{0.5f, 1.5f, 0.5f, 2.5f, 1.5f, 0.5f, 3.5f, 2.5f, 0.5f, 0.5f, 2.5f, 0.5f, 0.5f, 0.5f, 1.5f}, + std::vector{0.5f, + 1.5f, + 0.5f, + 2.5f, + 1.5f, + 0.5f, + 3.5f, + 2.5f, + 0.5f, + 0.5f, + 2.5f, + 0.5f, + 0.5f, + 0.5f, + 1.5f}), + // destination bf16 ConvertParams( ConversionTypes::CONVERT, diff --git a/src/tests/test_utils/common_test_utils/include/common_test_utils/data_utils.hpp b/src/tests/test_utils/common_test_utils/include/common_test_utils/data_utils.hpp index 4f318ef98b3f03..28a26f4d754a12 100644 --- a/src/tests/test_utils/common_test_utils/include/common_test_utils/data_utils.hpp +++ b/src/tests/test_utils/common_test_utils/include/common_test_utils/data_utils.hpp @@ -525,6 +525,14 @@ inline ov::float16 ie_abs(const ov::float16& val) { return ov::float16::from_bits(val.to_bits() & 0x7FFF); } +inline ov::float8_e4m3 ie_abs(const ov::float8_e4m3& val) { + return ov::float8_e4m3::from_bits(val.to_bits() & 0x7F); +} + +inline ov::float8_e5m2 ie_abs(const ov::float8_e5m2& val) { + return ov::float8_e5m2::from_bits(val.to_bits() & 0x7F); +} + } // namespace utils } // namespace test } // namespace ov