Skip to content

Commit

Permalink
Add to/from E8M0 scale MX format conversion. (#131)
Browse files Browse the repository at this point in the history
* Add to/from E8M0 scale MX format conversion.

Implementation using bitmasking & shifting, so hopefully decently fast!

* Pinning JAX to <0.4.31 until sharding parameter bug is solved.

* Fixing backward compatibility with JAX 0.3.16.

`ml_dtypes` and JAX bfloat16 dtypes not equivalent older JAX versions.
  • Loading branch information
balancap authored Aug 12, 2024
1 parent 6964cd9 commit 32f81b9
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 5 deletions.
29 changes: 28 additions & 1 deletion jax_scalify/core/pow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
import ml_dtypes
import numpy as np
from jax import core
from jax.interpreters import mlir
Expand All @@ -14,6 +16,13 @@

# Exponent bits masking.
_exponent_bits_mask: Dict[Any, NDArray[Any]] = {
np.dtype(jnp.bfloat16): np.packbits(
np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8)
).view(np.int16),
# Copy for ml_dtypes.bfloat16, distinct in older JAX versions.
np.dtype(ml_dtypes.bfloat16): np.packbits(
np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8)
).view(np.int16),
np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view(
np.int16
),
Expand All @@ -31,6 +40,24 @@
"""


def dtype_exponent_mask(dtype: DTypeLike, sign_bit: bool = False) -> NDArray[Any]:
"""Get the exponent mask for a given Numpy/JAX dtype.
Args:
dtype: Numpy/JAX dtype.
sign_bit: Include sign bit in the mask.
Returns:
Array mask as integer dtype.
"""
mask = _exponent_bits_mask[dtype]
if sign_bit:
# Negative value to add sign.
intdtype = mask.dtype
mask = (-mask.view(dtype)).view(intdtype)
return mask
return mask


def pow2_decompose_round_down_impl(vin: Array, scale_dtype: DTypeLike) -> Array:
"""Pow-2 decompose with rounding down.
Expand All @@ -42,7 +69,7 @@ def pow2_decompose_round_down_impl(vin: Array, scale_dtype: DTypeLike) -> Array:
# NOTE: `jnp.frexp` is buggy for subnormals.
dtype = np.dtype(np.float32)
minval = np.finfo(dtype).smallest_normal
exponent_mask = _exponent_bits_mask[dtype]
exponent_mask = dtype_exponent_mask(dtype)
intdtype = exponent_mask.dtype
val = vin.astype(dtype)
# Masking mantissa bits, keeping only the exponents ones.
Expand Down
2 changes: 2 additions & 0 deletions jax_scalify/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from .scale import as_e8m0 # noqa: F401
63 changes: 63 additions & 0 deletions jax_scalify/quantization/scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import ml_dtypes
import numpy as np

from jax_scalify.core import Array, DTypeLike, get_numpy_api
from jax_scalify.core.pow2 import dtype_exponent_mask


def pow2_truncate(arr: Array) -> Array:
"""Convert an Array to a power of 2, using mantissa truncation.
NOTE: all sub-normals values are flushed to zero.
"""
np_api = get_numpy_api(arr)
# Masking mantissa & sign-bit, keeping only exponent values.
exponent_mask = dtype_exponent_mask(arr.dtype, sign_bit=True)
intdtype = exponent_mask.dtype
# Masking mantissa bits, keeping only the exponents ones.
arr_pow2 = np_api.bitwise_and(arr.view(intdtype), exponent_mask).view(arr.dtype).reshape(arr.shape)
return arr_pow2


def as_e8m0(arr: Array) -> Array:
"""Convert an Array to e8m0 format (i.e. power of two values).
This function is only implementing a truncation + saturation variant, in line with
the MX OCP format.
Args:
arr: Input array (FP16, FP32 or BF16).
Returns:
E8M0 array (as uint8).
"""
np_api = get_numpy_api(arr)
# assert len(arr.shape) < 2
assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)}
# Saturation => negative values saturating to min value (i.e. zero bits) in E8M0.
arr = np_api.maximum(arr, np.array(0, arr.dtype))
arr = pow2_truncate(arr)

# Bit masking to extract the exponent as uint8 array.
arr_u8 = arr.view(np.uint8).reshape((*arr.shape, -1))
arr_e8m0 = np_api.bitwise_or(np_api.left_shift(arr_u8[..., -1], 1), np_api.right_shift(arr_u8[..., -2], 7))
return arr_e8m0


def from_e8m0(arr: Array, dtype: DTypeLike) -> Array:
"""Convert an Array of e8m0 values (i.e. power of two values) to a given dtype.
Args:
arr: E8M0 array (assuming uint8 storage dtype).
dtype: Output dtype. FP32 or BF16 supported.
Returns:
Converted output.
"""
np_api = get_numpy_api(arr)
assert arr.dtype == np.uint8
assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)}
# Avoid issues with 7 mantissa bits in BF16.
# TODO: more efficient implementation!
arr = np_api.exp2(arr.astype(np.float32) - 127)
return arr.astype(dtype)
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"chex >= 0.1.6",
"jax >= 0.3.16",
"jaxlib >= 0.3.15",
"chex>=0.1.6",
"jax>=0.3.16,<0.4.31",
"jaxlib>=0.3.15",
"ml_dtypes",
"numpy >= 1.22.4"
"numpy>=1.22.4"
]
dynamic = ["version"]

Expand Down
58 changes: 58 additions & 0 deletions tests/quantization/test_scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import chex
import ml_dtypes
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scalify.quantization.scale import as_e8m0, from_e8m0, pow2_truncate


class QuantizationScaleTests(chex.TestCase):
@parameterized.parameters(
{"dtype": np.float16},
{"dtype": np.float32},
{"dtype": ml_dtypes.bfloat16},
)
def test__pow2_truncate__proper_result(self, dtype):
vin = np.array([-2, 0, 2, 1, 9, 15]).astype(dtype)
vout = pow2_truncate(vin)
assert vout.dtype == vin.dtype
npt.assert_array_equal(vout, [-2.0, 0.0, 2.0, 1.0, 8.0, 8.0])

@parameterized.parameters(
# {"dtype": np.float16},
{"dtype": np.float32},
{"dtype": ml_dtypes.bfloat16},
)
def test__as_e8m0__positive_values(self, dtype):
vin = np.array([0.6, 2, 1, 9, 15, 127]).astype(dtype).reshape((-1, 2))
vout = as_e8m0(vin)
assert vout.dtype == np.uint8
assert vout.shape == vin.shape
npt.assert_array_equal(vout, np.log2(pow2_truncate(vin)) + 127)

@parameterized.parameters(
# {"dtype": np.float16},
{"dtype": np.float32},
{"dtype": ml_dtypes.bfloat16},
)
def test__as_e8m0__negative_values(self, dtype):
vin = np.array([-0.1, -3, 0, 2**-127]).astype(dtype)
vout = as_e8m0(vin)
assert vout.dtype == np.uint8
# NOTE: uint8(0) is the smallest positive scale in E8M0.
npt.assert_array_equal(vout, np.uint8(0))

@parameterized.parameters(
# {"dtype": np.float16},
{"dtype": np.float32},
{"dtype": ml_dtypes.bfloat16},
)
def test__from_e8m0(self, dtype):
vin = np.array([2**-127, 0.25, 1, 2, 8, 2**127.0]).astype(dtype).reshape((-1, 2))
vin_e8m0 = as_e8m0(vin)
vout = from_e8m0(vin_e8m0, dtype)
assert vin.dtype == vout.dtype
assert vout.shape == vin.shape
npt.assert_array_equal(vout, vin)

0 comments on commit 32f81b9

Please sign in to comment.