Skip to content

Commit

Permalink
Add numeric_limits for MLFloat16 and BFloat16 (microsoft#22197)
Browse files Browse the repository at this point in the history
### Description
* Add std::numeric_limits for MLFloat16 and BFloat16.
* Update some comments in csharp ORTFloat16.shared.cs.
* Add unit tests (including Clip)

Note that the canonical NaN is not consistent in C++ and C#. C# uses
negative quiet NaN as canonical NaN, while C++ uses positive quiet NaN.
The choice of CSharp Float16.NaN is to be consistent with
System.Half.NaN.

FP16 data returns from CUDA might have 7FFF as NaN; FP16 data from CPU
provider might have 0x7E00 as NaN. Anyway there is no consistent
canonical NaN in ORT right now. Because all these NaNs are aligned with
IEEE spec, there shall not an issue in downstream.

### Motivation and Context
std::numeric_limits is used in codebase but not defined for MLFloat16
and BFloat16. It causes some bugs like
microsoft#21957 introduced by
microsoft#21493.
  • Loading branch information
tianleiwu authored Sep 26, 2024
1 parent 72b0979 commit 7880342
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 57 deletions.
70 changes: 37 additions & 33 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ internal static int LeadingZeroCount(uint num)
/// <summary>
/// Extracts single precision number bit representation as uint
/// so its bits can be manipulated.
///
///
/// This API is the reverse of UInt32BitsToSingle().
///
///
/// </summary>
/// <param name="single">float value</param>
/// <returns></returns>
Expand All @@ -79,11 +79,11 @@ internal static uint SingleToUInt32Bits(float single)
/// <summary>
/// Needed because BitConverter impl is not available until
/// later versions. This API is the reverse of SingleToUInt32Bits().
///
///
/// For the exact bit representation of float see IEEE 754 standard for single precision.
///
///
/// </summary>
/// <param name="singleBits">bit representation of float either obtained from
/// <param name="singleBits">bit representation of float either obtained from
/// SingleToUInt32Bits or assembled using bitwise operators</param>
/// <returns></returns>
internal static float UInt32BitsToSingle(uint singleBits)
Expand All @@ -99,7 +99,7 @@ internal static float UInt32BitsToSingle(uint singleBits)
/// <summary>
/// Converts single precision bits representation which can be obtained using
/// SingleToUInt32Bits() or manually constructed according to IEEE 754 standard.
///
///
/// </summary>
/// <param name="singleBits">bits representation of a single precision number (float)</param>
/// <returns></returns>
Expand Down Expand Up @@ -177,8 +177,8 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
///
/// The implementation is derived from
///
/// The implementation is derived from
/// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974
/// </summary>
[StructLayout(LayoutKind.Sequential)]
Expand Down Expand Up @@ -215,6 +215,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)

private const ushort OneBits = 0x3C00;

// Minimum positive normalized value. It is corresponding to numeric_limits<float16>::min() in C++.
private const ushort EpsilonBits = 0x0400;

private const ushort PositiveInfinityBits = 0x7C00;
Expand All @@ -238,7 +239,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
/// <summary>
/// Float16 Epsilon value
/// </summary>
public static Float16 Epsilon => new Float16(EpsilonBits); // 5.9604645E-08
public static Float16 Epsilon => new Float16(EpsilonBits); // 0.00006103515625

/// <summary>
/// Float16 Pi value
Expand All @@ -248,17 +249,17 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
/// <summary>
/// Float16 Positive Infinity value
/// </summary>
public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); // 1.0 / 0.0;
public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits);

/// <summary>
/// Float16 Negative Infinity value
/// </summary>
public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); // -1.0 / 0.0
public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits);

/// <summary>
/// Float16 NaN
/// </summary>
public static Float16 NaN => new Float16(NegativeQNaNBits); // 0.0 / 0.0
public static Float16 NaN => new Float16(NegativeQNaNBits); // Same as System.Half.NaN

/// <summary>
/// Float16 Zero value
Expand All @@ -276,14 +277,14 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand)
public static Float16 NegativeZero => new Float16(NegativeZeroBits); // -0.0

/// <summary>
/// Float16 Min value
/// Float16 Lowest value
/// </summary>
public static Float16 MinValue => new Float16(MinValueBits); // 64,511
public static Float16 MinValue => new Float16(MinValueBits); // -65504.0

/// <summary>
/// Float16 Max value
/// </summary>
public static Float16 MaxValue => new Float16(MaxValueBits); // 31,743
public static Float16 MaxValue => new Float16(MaxValueBits); // 65504.0

/// <summary>
/// float16 representation bits
Expand Down Expand Up @@ -348,7 +349,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)

/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -376,7 +377,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)

/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand All @@ -388,7 +389,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)

/// <summary>
/// Compares values of two Float16
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -429,7 +430,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
/// <summary>
/// Compares values of two Float16 for binary equality.
/// If either of the values is NaN, this will return false.
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -479,7 +480,7 @@ public static bool IsInfinity(Float16 value)
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
///
///
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(Float16 value)
Expand All @@ -500,7 +501,7 @@ public static bool IsNegative(Float16 value)
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
///
///
/// <param name="value">Float16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(Float16 value)
Expand Down Expand Up @@ -549,7 +550,7 @@ public static bool IsSubnormal(Float16 value)
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// </summary>
///
///
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
Expand All @@ -570,7 +571,7 @@ public int CompareTo(object obj)
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(Float16 other)
{
Expand Down Expand Up @@ -864,10 +865,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
private const ushort PositiveQNaNBits = 0x7FC1;
private const ushort NegativeQNaNBits = 0xFFC1;

// Lowest finite value. It is corresponding to numeric_limits<BFloat16>::lowest() in C++.
private const ushort MinValueBits = 0xFF7F; // 1b0_11111110_1111111

private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111

private const ushort EpsilonBits = 0x0080; // the smallest positive normal value
// Minimum positive normalized value. It is corresponding to numeric_limits<BFloat16>::min() in C++.
private const ushort EpsilonBits = 0x0080;

private const ushort PiBits = 0x4049; // 0b0_10000000_1001001

Expand Down Expand Up @@ -899,7 +903,7 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
/// <summary>
/// BFloat16 NaN
/// </summary>
public static BFloat16 NaN => new BFloat16(NegativeQNaNBits);
public static BFloat16 NaN => new BFloat16(NegativeQNaNBits); // .Net has no BFloat16. Follow Float16 style.

/// <summary>
/// BFloat16 Positive Zero
Expand All @@ -919,13 +923,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig)
/// <summary>
/// BFloat16 Min value
/// </summary>
public static BFloat16 MinValue => new BFloat16(MinValueBits); // 65,407
public static BFloat16 MinValue => new BFloat16(MinValueBits); // -3.38953139e38

/// <summary>
/// BFloat16 Max value
/// </summary>

public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 32,639
public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 3.38953139e38

/// <summary>
/// bfloat16 representation bits
Expand Down Expand Up @@ -1051,7 +1055,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits)
/// <summary>
/// Compares values of two BFloat16 for binary equality.
/// If either of the values is NaN, this will return false.
///
///
/// </summary>
/// <param name="left">left hand side</param>
/// <param name="right">right hand side</param>
Expand Down Expand Up @@ -1102,7 +1106,7 @@ public static bool IsInfinity(BFloat16 value)
/// <summary>
/// Determines whether the specified value is NaN.
/// </summary>
///
///
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is not a number</returns>
public static bool IsNaN(BFloat16 value)
Expand All @@ -1123,7 +1127,7 @@ public static bool IsNegative(BFloat16 value)
/// <summary>
/// Determines whether the specified value is negative infinity.
/// </summary>
///
///
/// <param name="value">BFloat16 instance</param>
/// <returns>true if the value is negative infinity</returns>
public static bool IsNegativeInfinity(BFloat16 value)
Expand Down Expand Up @@ -1170,7 +1174,7 @@ public static bool IsSubnormal(BFloat16 value)
/// <summary>
/// Compares this object to another object, returning an integer that indicates the relationship.
/// </summary>
///
///
/// <param name="obj">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="obj"/>,
/// zero if this is equal to <paramref name="obj"/>, or a value greater than zero
Expand All @@ -1191,7 +1195,7 @@ public int CompareTo(object obj)
/// </summary>
/// <param name="other">Object to compare to</param>
/// <returns>A value less than zero if this is less than <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// zero if this is equal to <paramref name="other"/>,
/// or a value greater than zero if this is greater than <paramref name="other"/>.</returns>
public int CompareTo(BFloat16 other)
{
Expand Down Expand Up @@ -1368,4 +1372,4 @@ private static uint StripSign(BFloat16 value)

#endregion
}
}
}
Loading

0 comments on commit 7880342

Please sign in to comment.