Skip to content

Commit

Permalink
Implement JAX pow2_decompose primitive.
Browse files Browse the repository at this point in the history
The primitive `pow2_decompose` is the core decomposition kernel used everywhere in AutoScale/Scalify,
meaning it is worth properly formalizing it as a JAX primitive, simplifying the Jaxpr level graph
and allowing proper custom kernel optimization on different HW platforms (GPU, IPU, TPU, ...).

NOTE: this PR is fixing additional subnormal related bugs, due to inconsistency of jnp.frexp vs Numpy.
See: jax-ml/jax#19689
  • Loading branch information
balancap committed Feb 9, 2024
1 parent 638e9f9 commit 19904d3
Show file tree
Hide file tree
Showing 15 changed files with 355 additions and 165 deletions.
2 changes: 1 addition & 1 deletion experiments/mnist/mnist_classifier_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def predict(params, inputs):
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
targets = jsa.lax.rebalance(targets, np.float32(1 / 16))
targets = jsa.lax.rebalance(targets, np.float32(1 / 8))
return -jnp.mean(jnp.sum(preds * targets, axis=1))


Expand Down
3 changes: 2 additions & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
register_scaled_lax_op,
register_scaled_op,
)
from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up, safe_div, safe_reciprocal # noqa: F401
from .utils import safe_div, safe_reciprocal # noqa: F401
12 changes: 5 additions & 7 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from jax.tree_util import register_pytree_node_class
from numpy.typing import ArrayLike, DTypeLike, NDArray

from .typing import Array, ArrayTypes, get_numpy_api
from .utils import get_mantissa, pow2_round_down
from .pow2 import Pow2RoundMode, pow2_decompose
from .typing import Array, ArrayTypes

GenericArray = Union[Array, np.ndarray]

Expand Down Expand Up @@ -121,13 +121,11 @@ def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> S
val = np.float32(val)
assert np.ndim(val) == 0
assert np.issubdtype(val.dtype, np.floating)
# Scale dtype to use.
# TODO: check the scale dtype?
# Scale dtype to use. TODO: check the scale dtype is valid?
scale_dtype = scale_dtype or val.dtype
# Split mantissa and exponent in data and scale components.
scale = pow2_round_down(val.astype(scale_dtype))
npapi = get_numpy_api(scale)
return ScaledArray(npapi.asarray(get_mantissa(val)), scale)
scale, mantissa = pow2_decompose(val, scale_dtype=scale_dtype, mode=Pow2RoundMode.DOWN)
return ScaledArray(mantissa, scale)


def is_scaled_leaf(val: Any) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
is_scaled_leaf,
is_static_zero,
)
from .utils import Pow2RoundMode, python_scalar_as_numpy
from .pow2 import Pow2RoundMode
from .utils import python_scalar_as_numpy


@dataclass(frozen=True)
Expand Down
163 changes: 163 additions & 0 deletions jax_scaled_arithmetics/core/pow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import logging
from enum import IntEnum
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np
from jax import core
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir
from numpy.typing import DTypeLike, NDArray

from .typing import Array, get_numpy_api

# Exponent bits masking.
_exponent_bits_mask: Dict[Any, NDArray[Any]] = {
np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view(
np.int16
),
np.dtype(np.float32): np.packbits(
np.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
dtype=np.uint8,
)
).view(np.int32),
np.dtype(np.float64): np.array(np.inf, np.float64).view(np.int64),
}
"""Exponents bit masking: explicit bitmask to keep only exponent bits in floating point values.
NOTE: normally should also correspond to `np.inf` value for FP16 and FP32.
"""


def pow2_decompose_round_down_impl(vin: Array, scale_dtype: DTypeLike) -> Array:
"""Pow-2 decompose with rounding down.
Returns:
(scale, vout) such that vin = scale * vout
"""
np_api = get_numpy_api(vin)
# Perform all computations in FP32, to support FP16 submormals.
# NOTE: `jnp.frexp` is buggy for subnormals.
dtype = np.dtype(np.float32)
minval = np.finfo(dtype).smallest_normal
exponent_mask = _exponent_bits_mask[dtype]
intdtype = exponent_mask.dtype
val = vin.astype(dtype)
# Masking mantissa bits, keeping only the exponents ones.
scale_pow2 = np_api.bitwise_and(val.view(intdtype), exponent_mask).view(val.dtype).reshape(val.shape)
# Get the mantissa in float32. Make sure we don't divide by zero, and handle nan/inf.
normal_scale_val = np_api.logical_and(np_api.isfinite(scale_pow2), scale_pow2 != 0)
scale_renorm = np_api.where(normal_scale_val, scale_pow2, minval)
mantissa = val / scale_renorm
return scale_pow2.astype(scale_dtype), mantissa.astype(vin.dtype)


class Pow2RoundMode(IntEnum):
"""Power-of-two supported rounded mode."""

NONE = 0
DOWN = 1
UP = 2
STOCHASTIC = 3


pow2_decompose_p = core.Primitive("pow2_decompose")
"""`pow2_decompose` pow2 decompose JAX primitive.
"""


def pow2_decompose(
vin: Array, scale_dtype: Optional[DTypeLike] = None, mode: Pow2RoundMode = Pow2RoundMode.DOWN
) -> Tuple[Array, Array]:
"""Power-2 decompose, i.e. vin = s * vout where s is a power-of 2 scaling.
Args:
vin: Input array.
scale_dtype: Scale dtype to use.
mode: Pow2 rounding.
Returns:
(scale, vout) such that vin = scale * vout
"""
scale_dtype = np.dtype(scale_dtype or vin.dtype)
# A couple of checks on dtypes.
assert np.issubdtype(vin.dtype, np.floating)
assert np.issubdtype(scale_dtype, np.floating)
if scale_dtype == np.float16:
logging.warning("`pow2_decompose` does not support FP16 sub-normals when using FP16 scale dtype.")
out = pow2_decompose_p.bind(vin, scale_dtype=scale_dtype, mode=mode)
return out


def pow2_decompose_eager_impl(
vin: Array, scale_dtype: Optional[DTypeLike] = None, mode: Pow2RoundMode = Pow2RoundMode.DOWN
) -> Tuple[Array, Array]:
"""Eager mode implementation, on JAX/Numpy arrays."""
if mode == Pow2RoundMode.DOWN:
return pow2_decompose_round_down_impl(vin, scale_dtype)
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")


def pow2_decompose_abstract_eval(
vin: core.ShapedArray, scale_dtype: Optional[DTypeLike] = None, mode: Pow2RoundMode = Pow2RoundMode.DOWN
) -> Tuple[core.ShapedArray, core.ShapedArray]:
scale_dtype = scale_dtype or vin.dtype
sout = core.ShapedArray(vin.shape, dtype=scale_dtype)
return (sout, vin)


def pow2_decompose_mlir_lowering(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
scale_dtype = params["scale_dtype"]
mode = params["mode"]
pow2_decompose_fn = partial(pow2_decompose_eager_impl, scale_dtype=scale_dtype, mode=mode)
outputs = mlir.lower_fun(pow2_decompose_fn, multiple_results=True)(ctx, *args)
return outputs


# Register as standard JAX primitive
pow2_decompose_p.multiple_results = True
pow2_decompose_p.def_abstract_eval(pow2_decompose_abstract_eval)
pow2_decompose_p.def_impl(pow2_decompose_eager_impl)
# Default lowering on GPU, TPU, ...
mlir.register_lowering(pow2_decompose_p, pow2_decompose_mlir_lowering)


def pow2_round_down(val: Array) -> Array:
"""Round down to the closest power of 2."""
# Keep only the scale component of `pow2_decompose`
pow2_val, _ = pow2_decompose(val, scale_dtype=val.dtype, mode=Pow2RoundMode.DOWN)
return pow2_val


def pow2_round_up(val: Array) -> Array:
"""Round up to the closest power of 2.
NOTE: may overflow to inf.
"""
# FIXME: rounding when already a power of 2.
# Should do additional masking to check that.
pow2_val = pow2_round_down(val) * np.array(2, dtype=val.dtype)
return pow2_val


def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
"""Power-of-two rounding."""
if mode == Pow2RoundMode.NONE:
return val
elif mode == Pow2RoundMode.DOWN:
return pow2_round_down(val)
elif mode == Pow2RoundMode.UP:
return pow2_round_up(val)
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")


def get_mantissa(val: Array) -> Array:
"""Extract the mantissa of an array, masking the exponent.
Similar to `numpy.frexp`, but with implicit bit to be consistent with
`pow2_round_down`.
"""
_, mantissa = pow2_decompose(val, scale_dtype=val.dtype, mode=Pow2RoundMode.DOWN)
return mantissa
77 changes: 2 additions & 75 deletions jax_scaled_arithmetics/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,11 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from enum import IntEnum
from typing import Any, Dict
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray

from .typing import Array, get_numpy_api

# Exponent bits masking.
_exponent_bits_mask: Dict[Any, NDArray[Any]] = {
np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view(
np.int16
),
np.dtype(np.float32): np.packbits(
np.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
dtype=np.uint8,
)
).view(np.int32),
np.dtype(np.float64): np.array(np.inf, np.float64).view(np.int64),
}
"""Exponents bit masking: explicit bitmask to keep only exponent bits in floating point values.
NOTE: normally should also correspond to `np.inf` value for FP16 and FP32.
"""


class Pow2RoundMode(IntEnum):
"""Power-of-two supported rounded mode."""

NONE = 0
DOWN = 1
UP = 2
STOCHASTIC = 3


def get_mantissa(val: Array) -> Array:
"""Extract the mantissa of an array, masking the exponent.
Similar to `numpy.frexp`, but with implicit bit to be consistent with
`pow2_round_down`.
"""
np_api = get_numpy_api(val)
# TODO: implement using bitmasking?
mantissa_val, _ = np_api.frexp(val)
# Re-add the implicit bit to be consistent with `pow2_round_down`
mantissa_val = mantissa_val * np.array(2, dtype=val.dtype)
return mantissa_val


def pow2_round_down(val: Array) -> Array:
"""Round down to the closest power of 2."""
np_api = get_numpy_api(val)
exponent_mask = _exponent_bits_mask[val.dtype]
intdtype = exponent_mask.dtype
pow2_val = np_api.bitwise_and(val.view(intdtype), exponent_mask).view(val.dtype).reshape(val.shape)
return pow2_val


def pow2_round_up(val: Array) -> Array:
"""Round up to the closest power of 2.
NOTE: may overflow to inf.
"""
# FIXME: rounding when already a power of 2.
# Should do additional masking to check that.
pow2_val = pow2_round_down(val) * np.array(2, dtype=val.dtype)
return pow2_val


def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
"""Power-of-two rounding."""
if mode == Pow2RoundMode.NONE:
return val
elif mode == Pow2RoundMode.DOWN:
return pow2_round_down(val)
elif mode == Pow2RoundMode.UP:
return pow2_round_up(val)
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")
from .typing import Array


def safe_div(lhs: Array, rhs: Array) -> Array:
Expand Down
7 changes: 7 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)


@core.register_scaled_lax_op
def scaled_sign(val: ScaledArray) -> Array:
assert isinstance(val, ScaledArray)
# Just need to run `lax.sign` on main data.
return ScaledArray(lax.sign(val.data), np.array(1, dtype=val.scale.dtype))


@core.register_scaled_lax_op
def scaled_is_finite(val: ScaledArray) -> Array:
assert isinstance(val, ScaledArray)
Expand Down
8 changes: 5 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def scaled_dot_general(
contracting_dim_size = lhs.shape[lhs_contracting_dims[0]]
# "unit scaling" rule, based on the contracting axis.
outscale_dtype = jnp.promote_types(lhs.scale.dtype, rhs.scale.dtype)
contracting_rescale = pow2_round(np.sqrt(contracting_dim_size), pow2_rounding_mode)
contracting_rescale = np.sqrt(contracting_dim_size).astype(outscale_dtype)
contracting_rescale = pow2_round(contracting_rescale, pow2_rounding_mode)
# Keeping power of 2 scale.
output_scale = lhs.scale * rhs.scale * contracting_rescale.astype(outscale_dtype)
# NOTE: need to be a bit careful about scale promotion?
Expand Down Expand Up @@ -107,14 +108,15 @@ def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params) ->
def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
scale_dtype = val.scale.dtype
axes_size = np.array([shape[idx] for idx in axes])
# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
# Rescale data component following reduction axes & round to power of 2 value.
axes_rescale = np.sqrt(np.prod(axes_size))
axes_rescale = np.sqrt(np.prod(axes_size)).astype(scale_dtype)
axes_rescale = pow2_round(axes_rescale, pow2_rounding_mode)
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale.astype(val.data.dtype)
outscale = val.scale * axes_rescale.astype(val.scale.dtype)
outscale = val.scale * axes_rescale.astype(scale_dtype)
return ScaledArray(data, outscale)


Expand Down
Loading

0 comments on commit 19904d3

Please sign in to comment.