Skip to content

Commit

Permalink
Use NaN from std::numeric_limits instead macro
Browse files Browse the repository at this point in the history
- minor refactor of float8_e4m3
  • Loading branch information
praasz authored and beleiuandrei committed Jan 15, 2024
1 parent 2121c38 commit 00771c4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
24 changes: 10 additions & 14 deletions src/core/src/type/float8_e4m3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ constexpr uint8_t f8e4m3_e_max = 0x0f; // f8e4m3 exponent max value
constexpr uint8_t f8e4m3_m_size = 3; // f8e4m3 mantissa bits size
constexpr uint8_t f8e4m3_m_mask = 0x07; // f8e4m3 mantissa bit mask

union f32_t {
float value;
uint32_t bits;
};

uint8_t f32_to_f8e4m3_bits(const float value) {
constexpr uint32_t f32_s_mask = 0x80000000; // f32 sign bit mask
constexpr uint32_t f32_e_mask = 0x7F800000; // f32 exponent bits mask
Expand All @@ -63,11 +68,6 @@ uint8_t f32_to_f8e4m3_bits(const float value) {
constexpr uint32_t round_even = 0x00800000; // value for half to even round for f8
constexpr uint32_t round_odd = 0x01800000; // value for an non-half to even round for f8

union f32_t {
float value;
uint32_t bits;
};

const auto input = f32_t{value};
auto f8_bits = static_cast<uint8_t>((input.bits & f32_s_mask) >> three_bytes_shift);

Expand All @@ -90,11 +90,11 @@ uint8_t f32_to_f8e4m3_bits(const float value) {
fractional &= f8_m_mask;

// set exponent and mantissa on f8 bits
if (f8_biased_exp > 15) {
if (f8_biased_exp > f8e4m3_e_max) {
// Use NAN as this type has no infinity
f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask);
} else if (f8_biased_exp > 0) {
f8_bits |= ((f8_biased_exp) << f8e4m3_m_size) | (fractional >> three_bytes_shift);
f8_bits |= (f8_biased_exp << f8e4m3_m_size) | (fractional >> three_bytes_shift);
} else {
// Restore the hidden 1 in f8 mantissa for subnormal calculation
fractional = f8_m_hidden_one_mask | (input.bits & f32_m_mask) << (f32_e_size - f8e4m3_e_size);
Expand Down Expand Up @@ -124,13 +124,9 @@ float8_e4m3::float8_e4m3(const uint32_t sign, const uint32_t biased_exponent, co
float8_e4m3::float8_e4m3(const float value) : m_value{f32_to_f8e4m3_bits(value)} {}

float8_e4m3::operator float() const {
union {
float float_value;
uint32_t bit_value;
};
float_value = f8_to_float_lut[m_value & (f8e4m3_e_mask | f8e4m3_m_mask)];
bit_value |= (m_value & f8e4m3_s_mask) << three_bytes_shift;
return float_value;
auto converted = f32_t{f8_to_float_lut[m_value & (f8e4m3_e_mask | f8e4m3_m_mask)]};
converted.bits |= (m_value & f8e4m3_s_mask) << three_bytes_shift;
return converted.value;
}

uint8_t float8_e4m3::to_bits() const {
Expand Down
5 changes: 3 additions & 2 deletions src/core/tests/float8_e4m3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ TEST(F8E4M3Test, to_string) {

EXPECT_EQ(std::to_string(f8), "1.250000");
}
constexpr auto f32_qnan = std::numeric_limits<float>::quiet_NaN();

const auto exp_floats = std::vector<float>{
0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
Expand All @@ -124,7 +125,7 @@ const auto exp_floats = std::vector<float>{
32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f,
64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f,
128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f,
256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, NAN,
256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, f32_qnan,
-0.0f, -0.001953125f, -0.00390625f, -0.005859375f, -0.0078125f, -0.009765625f, -0.01171875f, -0.013671875f,
-0.015625f, -0.017578125f, -0.01953125f, -0.021484375f, -0.0234375f, -0.025390625f, -0.02734375f, -0.029296875f,
-0.03125f, -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f,
Expand All @@ -140,7 +141,7 @@ const auto exp_floats = std::vector<float>{
-32.0f, -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f,
-64.0f, -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f,
-128.0f, -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f,
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, -NAN};
-256.0f, -288.0f, -320.0f, -352.0f, -384.0f, -416.0f, -448.0f, -f32_qnan};

using f8m4e3_params = std::tuple<int, float>;
class F8E4M3PTest : public testing::TestWithParam<f8m4e3_params> {};
Expand Down
7 changes: 5 additions & 2 deletions src/core/tests/float8_e5m2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ std::vector<std::tuple<int, typename TContainer::value_type>> enumerate(const TC
return enum_values;
}

constexpr auto f32_qnan = std::numeric_limits<float>::quiet_NaN();
constexpr auto f32_inf = std::numeric_limits<float>::infinity();

// clang-format off
const auto exp_floats = std::vector<float>{
0.0f, 1.52587890625e-05f, 3.0517578125e-05f, 4.57763671875e-05f,
Expand Down Expand Up @@ -70,7 +73,7 @@ const auto exp_floats = std::vector<float>{
8192.0f, 10240.0f, 12288.0f, 14336.0f,
16384.0f, 20480.0f, 24576.0f, 28672.0f,
32768.0f, 40960.0f, 49152.0f, 57344.0f,
INFINITY, NAN, NAN, NAN,
f32_inf, f32_qnan, f32_qnan, f32_qnan,
-0.0f, -1.52587890625e-05f, -3.0517578125e-05f, -4.57763671875e-05f,
-6.103515625e-05f, -7.62939453125e-05f, -9.1552734375e-05f, -0.0001068115234375f,
-0.0001220703125f, -0.000152587890625f, -0.00018310546875f, -0.000213623046875f,
Expand Down Expand Up @@ -102,7 +105,7 @@ const auto exp_floats = std::vector<float>{
-8192.0f, -10240.0f, -12288.0f, -14336.0f,
-16384.0f, -20480.0f, -24576.0f, -28672.0f,
-32768.0f, -40960.0f, -49152.0f, -57344.0f,
-INFINITY, -NAN, -NAN, -NAN};
-f32_inf, -f32_qnan, -f32_qnan, -f32_qnan};
// clang-format on

using f8m5e2_params = std::tuple<int, float>;
Expand Down

0 comments on commit 00771c4

Please sign in to comment.