-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add to/from E8M0 scale MX format conversion. (#131)
* Add to/from E8M0 scale MX format conversion. Implementation using bitmasking & shifting, so hopefully decently fast! * Pinning JAX to <0.4.31 until sharding parameter bug is solved. * Fixing backward compatibility with JAX 0.3.16. `ml_dtypes` and JAX bfloat16 dtypes not equivalent older JAX versions.
- Loading branch information
Showing
5 changed files
with
155 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
from .scale import as_e8m0 # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
import jax.numpy as jnp | ||
import ml_dtypes | ||
import numpy as np | ||
|
||
from jax_scalify.core import Array, DTypeLike, get_numpy_api | ||
from jax_scalify.core.pow2 import dtype_exponent_mask | ||
|
||
|
||
def pow2_truncate(arr: Array) -> Array: | ||
"""Convert an Array to a power of 2, using mantissa truncation. | ||
NOTE: all sub-normals values are flushed to zero. | ||
""" | ||
np_api = get_numpy_api(arr) | ||
# Masking mantissa & sign-bit, keeping only exponent values. | ||
exponent_mask = dtype_exponent_mask(arr.dtype, sign_bit=True) | ||
intdtype = exponent_mask.dtype | ||
# Masking mantissa bits, keeping only the exponents ones. | ||
arr_pow2 = np_api.bitwise_and(arr.view(intdtype), exponent_mask).view(arr.dtype).reshape(arr.shape) | ||
return arr_pow2 | ||
|
||
|
||
def as_e8m0(arr: Array) -> Array: | ||
"""Convert an Array to e8m0 format (i.e. power of two values). | ||
This function is only implementing a truncation + saturation variant, in line with | ||
the MX OCP format. | ||
Args: | ||
arr: Input array (FP16, FP32 or BF16). | ||
Returns: | ||
E8M0 array (as uint8). | ||
""" | ||
np_api = get_numpy_api(arr) | ||
# assert len(arr.shape) < 2 | ||
assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)} | ||
# Saturation => negative values saturating to min value (i.e. zero bits) in E8M0. | ||
arr = np_api.maximum(arr, np.array(0, arr.dtype)) | ||
arr = pow2_truncate(arr) | ||
|
||
# Bit masking to extract the exponent as uint8 array. | ||
arr_u8 = arr.view(np.uint8).reshape((*arr.shape, -1)) | ||
arr_e8m0 = np_api.bitwise_or(np_api.left_shift(arr_u8[..., -1], 1), np_api.right_shift(arr_u8[..., -2], 7)) | ||
return arr_e8m0 | ||
|
||
|
||
def from_e8m0(arr: Array, dtype: DTypeLike) -> Array: | ||
"""Convert an Array of e8m0 values (i.e. power of two values) to a given dtype. | ||
Args: | ||
arr: E8M0 array (assuming uint8 storage dtype). | ||
dtype: Output dtype. FP32 or BF16 supported. | ||
Returns: | ||
Converted output. | ||
""" | ||
np_api = get_numpy_api(arr) | ||
assert arr.dtype == np.uint8 | ||
assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)} | ||
# Avoid issues with 7 mantissa bits in BF16. | ||
# TODO: more efficient implementation! | ||
arr = np_api.exp2(arr.astype(np.float32) - 127) | ||
return arr.astype(dtype) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
import chex | ||
import ml_dtypes | ||
import numpy as np | ||
import numpy.testing as npt | ||
from absl.testing import parameterized | ||
|
||
from jax_scalify.quantization.scale import as_e8m0, from_e8m0, pow2_truncate | ||
|
||
|
||
class QuantizationScaleTests(chex.TestCase): | ||
@parameterized.parameters( | ||
{"dtype": np.float16}, | ||
{"dtype": np.float32}, | ||
{"dtype": ml_dtypes.bfloat16}, | ||
) | ||
def test__pow2_truncate__proper_result(self, dtype): | ||
vin = np.array([-2, 0, 2, 1, 9, 15]).astype(dtype) | ||
vout = pow2_truncate(vin) | ||
assert vout.dtype == vin.dtype | ||
npt.assert_array_equal(vout, [-2.0, 0.0, 2.0, 1.0, 8.0, 8.0]) | ||
|
||
@parameterized.parameters( | ||
# {"dtype": np.float16}, | ||
{"dtype": np.float32}, | ||
{"dtype": ml_dtypes.bfloat16}, | ||
) | ||
def test__as_e8m0__positive_values(self, dtype): | ||
vin = np.array([0.6, 2, 1, 9, 15, 127]).astype(dtype).reshape((-1, 2)) | ||
vout = as_e8m0(vin) | ||
assert vout.dtype == np.uint8 | ||
assert vout.shape == vin.shape | ||
npt.assert_array_equal(vout, np.log2(pow2_truncate(vin)) + 127) | ||
|
||
@parameterized.parameters( | ||
# {"dtype": np.float16}, | ||
{"dtype": np.float32}, | ||
{"dtype": ml_dtypes.bfloat16}, | ||
) | ||
def test__as_e8m0__negative_values(self, dtype): | ||
vin = np.array([-0.1, -3, 0, 2**-127]).astype(dtype) | ||
vout = as_e8m0(vin) | ||
assert vout.dtype == np.uint8 | ||
# NOTE: uint8(0) is the smallest positive scale in E8M0. | ||
npt.assert_array_equal(vout, np.uint8(0)) | ||
|
||
@parameterized.parameters( | ||
# {"dtype": np.float16}, | ||
{"dtype": np.float32}, | ||
{"dtype": ml_dtypes.bfloat16}, | ||
) | ||
def test__from_e8m0(self, dtype): | ||
vin = np.array([2**-127, 0.25, 1, 2, 8, 2**127.0]).astype(dtype).reshape((-1, 2)) | ||
vin_e8m0 = as_e8m0(vin) | ||
vout = from_e8m0(vin_e8m0, dtype) | ||
assert vin.dtype == vout.dtype | ||
assert vout.shape == vin.shape | ||
npt.assert_array_equal(vout, vin) |