Use power-of-two scaling in autoscale scaled translation ops rules.
As shown in issue #60, propagating non-power-of-two scaling factors can decrease training accuracy in low precision (typically FP16):
the additional rescaling operations introduce non-negligible accumulated floating-point errors.

This PR adds the option to round the scale to a power of two in scaled translations, currently supporting only rounding up and down. The rounding mode
can be modified in the config dataclass `AutoScaleConfig`. The updated scaled translations are: `dot_general`, `add`, `sub` and `reduce_sum`.

Finally, when implicitly converting scalars to scaled arrays, the method `make_scaled_scalar` now splits the input mantissa and exponent between data and scale.
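
The mantissa/exponent split can be illustrated with `numpy.frexp` (a minimal sketch of the idea, not the repository implementation, which uses the `get_mantissa` and `pow2_round_down` helpers shown in the diff below):

```python
import numpy as np

def split_pow2_sketch(val):
    # frexp: val == m * 2**e, with mantissa |m| in [0.5, 1).
    m, e = np.frexp(val)
    data = m * np.float32(2)                # |data| in [1, 2)
    scale = np.ldexp(np.float32(1), e - 1)  # power-of-two scale
    return data, scale

data, scale = split_pow2_sketch(np.float32(0.75))
# data == 1.5, scale == 0.5, and data * scale == 0.75
```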
balancap committed Jan 3, 2024
1 parent 1c4047c commit 81b876c
Showing 13 changed files with 315 additions and 22 deletions.
11 changes: 10 additions & 1 deletion jax_scaled_arithmetics/__init__.py
@@ -1,4 +1,13 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, as_scaled_array, asarray, autoscale, debug_callback, scaled_array # noqa: F401
from .core import ( # noqa: F401
AutoScaleConfig,
Pow2RoundMode,
ScaledArray,
as_scaled_array,
asarray,
autoscale,
debug_callback,
scaled_array,
)
4 changes: 4 additions & 0 deletions jax_scaled_arithmetics/core/__init__.py
@@ -9,14 +9,18 @@
is_scaled_leaf,
is_static_one_scalar,
is_static_zero,
make_scaled_scalar,
scaled_array,
)
from .debug import debug_callback # noqa: F401
from .interpreters import ( # noqa: F401
AutoScaleConfig,
ScaledPrimitiveType,
autoscale,
find_registered_scaled_op,
get_autoscale_config,
register_scaled_lax_op,
register_scaled_op,
)
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
from .utils import Pow2RoundMode, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
31 changes: 28 additions & 3 deletions jax_scaled_arithmetics/core/datatype.py
@@ -10,7 +10,8 @@
from jax.tree_util import register_pytree_node_class
from numpy.typing import ArrayLike, DTypeLike, NDArray

from .typing import Array, ArrayTypes
from .typing import Array, ArrayTypes, get_numpy_api
from .utils import get_mantissa, pow2_round_down

GenericArray = Union[Array, np.ndarray]

@@ -42,6 +43,7 @@ class ScaledArray:
scale: GenericArray

def __post_init__(self):
# TODO/FIXME: support number as data?
assert isinstance(self.data, (*ArrayTypes, np.ndarray))
assert isinstance(self.scale, (*ArrayTypes, np.ndarray, np.number))
# Only supporting scale scalar for now.
@@ -93,6 +95,29 @@ def aval(self) -> ShapedArray:
return ShapedArray(self.data.shape, self.data.dtype)


def make_scaled_scalar(val: Array) -> ScaledArray:
"""Make a scaled scalar (array), from a single value.
The returned scalar will always be built such that:
- data is scalar in [1, 2)
- scale is a power-of-2 value.
NOTE: data is chosen in [1, 2) instead of [0, 1) in order to
keep any value representable in the same dtype, without overflowing.
NOTE bis: only supporting floating point input.
"""
# FIXME: implicit conversion from float64 to float32???
if isinstance(val, float):
val = np.float32(val)
assert np.ndim(val) == 0
assert np.issubdtype(val.dtype, np.floating)
# Split mantissa and exponent in data and scale components.
scale = pow2_round_down(val)
npapi = get_numpy_api(scale)
return ScaledArray(npapi.asarray(get_mantissa(val)), scale)


def is_scaled_leaf(val: Any) -> bool:
"""Is input a JAX PyTree (scaled) leaf, including ScaledArray.
@@ -135,15 +160,15 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[A
if is_static_one_scale and isinstance(val, (bool, int)):
return val
if is_static_one_scale and isinstance(val, float):
return ScaledArray(np.array(1, dtype=np.float32), np.float32(val))
return make_scaled_scalar(np.float32(val))

# Ignored dtypes by default: int and bool
ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_)
if ignored_dtype:
return val
# Floating point scalar
if val.ndim == 0 and is_static_one_scale:
return ScaledArray(np.array(1, dtype=val.dtype), val)
return make_scaled_scalar(val)

scale = np.array(1, dtype=val.dtype) if scale is None else scale
if isinstance(val, (np.ndarray, Array)):
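A quick sketch of the intended behaviour of the new helper (the expected values follow from the mantissa/exponent split described above):

```python
import numpy as np
from jax_scaled_arithmetics.core import make_scaled_scalar

sv = make_scaled_scalar(np.float32(0.75))
# 0.75 == 1.5 * 0.5: data in [1, 2), scale an exact power of two.
print(sv.data, sv.scale)  # 1.5 0.5
np.testing.assert_equal(np.asarray(sv), np.float32(0.75))
```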
30 changes: 30 additions & 0 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, Sequence, Tuple
@@ -15,6 +16,35 @@
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


@dataclass(frozen=True)
class AutoScaleConfig:
"""AutoScale configuration/parameters when tracing a graph.
NOTE: this config can be locally changed using a Python context manager:
`with AutoScaleConfig(...):`
"""

rounding_mode: Pow2RoundMode = Pow2RoundMode.DOWN

def __enter__(self):
global _autoscale_config_stack
_autoscale_config_stack.append(self)

def __exit__(self, exc_type, exc_val, exc_tb):
global _autoscale_config_stack
_autoscale_config_stack.pop()


# AutoScale config stack.
_autoscale_config_stack = [AutoScaleConfig()]


def get_autoscale_config() -> AutoScaleConfig:
"""Get current/local autoscale config."""
return _autoscale_config_stack[-1]


class ScaledPrimitiveType(IntEnum):
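With the config stack in place, the rounding mode can be changed locally through the context manager protocol defined above (usage sketch):

```python
from jax_scaled_arithmetics import AutoScaleConfig, Pow2RoundMode
from jax_scaled_arithmetics.core import get_autoscale_config

assert get_autoscale_config().rounding_mode == Pow2RoundMode.DOWN  # default

# Locally round scales up instead of down in scaled translations.
with AutoScaleConfig(rounding_mode=Pow2RoundMode.UP):
    assert get_autoscale_config().rounding_mode == Pow2RoundMode.UP

# Previous config is restored when the context exits.
assert get_autoscale_config().rounding_mode == Pow2RoundMode.DOWN
```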
9 changes: 6 additions & 3 deletions jax_scaled_arithmetics/core/typing.py
@@ -19,10 +19,13 @@
def get_numpy_api(val: Any) -> Any:
"""Get the Numpy API corresponding to an array.
Using the NumPy API whenever possible when tracing a JAX graph
allows for simple constant folding optimization.
JAX or classic Numpy supported.
"""
if isinstance(val, jax.Array):
return jnp
elif isinstance(val, (np.ndarray, np.number)):
if isinstance(val, (np.ndarray, np.number)):
return np
if isinstance(val, ArrayTypes):
return jnp
raise NotImplementedError(f"Unsupported input type '{type(val)}'. No matching Numpy API.")
79 changes: 79 additions & 0 deletions jax_scaled_arithmetics/core/utils.py
@@ -0,0 +1,79 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from enum import IntEnum
from typing import Any, Dict

import numpy as np
from numpy.typing import NDArray

from .typing import Array, get_numpy_api

# Exponent bits masking.
_exponent_bits_mask: Dict[Any, NDArray[Any]] = {
np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view(
np.int16
),
np.dtype(np.float32): np.packbits(
np.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
dtype=np.uint8,
)
).view(np.int32),
np.dtype(np.float64): np.array(np.inf, np.float64).view(np.int64),
}
"""Exponents bit masking: explicit bitmask to keep only exponent bits in floating point values.
NOTE: normally should also correspond to `np.inf` value for FP16 and FP32.
"""


class Pow2RoundMode(IntEnum):
"""Power-of-two supported rounded mode."""

NONE = 0
DOWN = 1
UP = 2
STOCHASTIC = 3


def get_mantissa(val: Array) -> Array:
"""Extract the mantissa of an array, masking the exponent.
Similar to `numpy.frexp`, but with implicit bit to be consistent with
`pow2_round_down`.
"""
np_api = get_numpy_api(val)
# TODO: implement using bitmasking?
mantissa_val, _ = np_api.frexp(val)
# Re-add the implicit bit to be consistent with `pow2_round_down`
mantissa_val = mantissa_val * np.array(2, dtype=val.dtype)
return mantissa_val


def pow2_round_down(val: Array) -> Array:
"""Round down to the closest power of 2."""
np_api = get_numpy_api(val)
exponent_mask = _exponent_bits_mask[val.dtype]
intdtype = exponent_mask.dtype
pow2_val = np_api.bitwise_and(val.view(intdtype), exponent_mask).view(val.dtype).reshape(val.shape)
return pow2_val


def pow2_round_up(val: Array) -> Array:
"""Round up to the closest power of 2.
NOTE: may overflow to inf.
"""
# FIXME: rounding when already a power of 2.
# Should do additional masking to check that.
pow2_val = pow2_round_down(val) * np.array(2, dtype=val.dtype)
return pow2_val


def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
"""Power-of-two rounding."""
if mode == Pow2RoundMode.NONE:
return val
elif mode == Pow2RoundMode.DOWN:
return pow2_round_down(val)
elif mode == Pow2RoundMode.UP:
return pow2_round_up(val)
raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.")
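
Expected values for the helpers above (a sketch of the semantics, importing directly from the new module):

```python
import numpy as np
from jax_scaled_arithmetics.core.utils import get_mantissa, pow2_round_down, pow2_round_up

x = np.array(5.0, dtype=np.float32)  # 5.0 == 1.25 * 2**2
print(pow2_round_down(x))  # 4.0
print(pow2_round_up(x))    # 8.0
print(get_mantissa(x))     # 1.25
```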
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -85,7 +85,7 @@ def scaled_reduce_precision(A: ScaledArray, exponent_bits: int, mantissa_bits: i
def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> ScaledArray:
# TODO: inputs checking (dtype and cie).
scales = jnp.array([v.scale for v in operands])
# Max rescaling of the collection of operands.
# Max rescaling of the collection of operands. Preserving pow2 scaling.
# TODO: explore alternative strategies?
outdtype = operands[0].dtype
scale_max = jnp.max(scales)
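The max-rescaling strategy referenced in the comment brings every operand onto the largest scale before concatenating; when all scales are powers of two, every rescaling ratio is itself an exact power of two (a simplified sketch, not the repository implementation):

```python
import numpy as np

def concat_scaled_sketch(datas, scales):
    # value_i == datas[i] * scales[i]; rescale all data onto the max scale.
    scale_max = max(scales)
    rescaled = [d * (s / scale_max) for d, s in zip(datas, scales)]
    return np.concatenate(rescaled), scale_max

datas = [np.array([1.0, 1.5], np.float32), np.array([1.25], np.float32)]
scales = [np.float32(0.5), np.float32(2.0)]
out, s = concat_scaled_sketch(datas, scales)
# out == [0.25, 0.375, 1.25], s == 2.0; out * s reproduces the input values.
```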
24 changes: 21 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -7,7 +7,14 @@
from jax._src.ad_util import add_any_p

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, as_scaled_array, register_scaled_op
from jax_scaled_arithmetics.core import (
DTypeLike,
ScaledArray,
as_scaled_array,
get_autoscale_config,
pow2_round,
register_scaled_op,
)

from .scaled_ops_common import check_scalar_scales, promote_scale_types

@@ -19,12 +26,16 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArra
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
# TODO: what happens to `sqrt` for non-floating scale?
# More stable than direct L2 norm, to avoid scale overflow.
ABscale_max = lax.max(A.scale, B.scale)
ABscale_min = lax.min(A.scale, B.scale)
ABscale_ratio = ABscale_min / ABscale_max
output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio)
# Transform back to power-of-2
output_scale = pow2_round(output_scale, pow2_rounding_mode)
# Output dtype => promotion of A and B dtypes.
outdtype = jnp.promote_types(A.dtype, B.dtype)
Arescale = (A.scale / output_scale).astype(outdtype)
@@ -63,10 +74,13 @@ def scaled_dot_general(
assert len(lhs_contracting_dims) == 1
assert len(rhs_contracting_dims) == 1

# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
contracting_dim_size = lhs.shape[lhs_contracting_dims[0]]
# "unit scaling" rule, based on the contracting axis.
outscale_dtype = jnp.promote_types(lhs.scale.dtype, rhs.scale.dtype)
contracting_rescale = np.sqrt(contracting_dim_size)
contracting_rescale = pow2_round(np.sqrt(contracting_dim_size), pow2_rounding_mode)
# Keeping power of 2 scale.
output_scale = lhs.scale * rhs.scale * contracting_rescale.astype(outscale_dtype)
# NOTE: need to be a bit careful about scale promotion?
output_data = lax.dot_general(
@@ -94,8 +108,11 @@ def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
# Rescale data component following reduction axes & round to power of 2 value.
axes_rescale = np.sqrt(np.prod(axes_size))
axes_rescale = pow2_round(axes_rescale, pow2_rounding_mode)
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale.astype(val.data.dtype)
outscale = val.scale * axes_rescale.astype(val.scale.dtype)
return ScaledArray(data, outscale)
@@ -107,6 +124,7 @@ def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
# Stable for power of 2.
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)

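Worked numbers for the two rules above with `Pow2RoundMode.DOWN` (the values follow from the formulas; a sketch, not repository output):

```python
import numpy as np

def pow2_down(x):
    # Same semantics as `pow2_round_down`, via frexp instead of bitmasking.
    m, e = np.frexp(x)
    return np.ldexp(np.float32(1), e - 1)

# scaled_add_sub: output_scale = max * sqrt(1 + (min/max)**2), then pow2 rounding.
a, b = np.float32(1.0), np.float32(0.5)
raw = np.maximum(a, b) * np.sqrt(1 + (np.minimum(a, b) / np.maximum(a, b)) ** 2)
print(raw, pow2_down(raw))  # ~1.118 -> 1.0

# scaled_dot_general / reduce_sum: sqrt of contracting/reduction size, rounded.
K = 512
print(np.sqrt(K), pow2_down(np.float32(np.sqrt(K))))  # ~22.63 -> 16.0
```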
56 changes: 54 additions & 2 deletions tests/core/test_datatype.py
@@ -15,6 +15,8 @@
is_scaled_leaf,
is_static_one_scalar,
is_static_zero,
make_scaled_scalar,
pow2_round_down,
scaled_array,
)

@@ -101,6 +103,56 @@ def test__scaled_array__numpy_array_interface(self, npapi):
assert isinstance(out, np.ndarray)
npt.assert_array_equal(out, sarr.data * sarr.scale)

@parameterized.parameters(
{"val": 0.25},
)
def test__make_scaled_scalar__float_input(self, val):
scaled_val = make_scaled_scalar(val)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.shape == ()
assert scaled_val.data.dtype == np.float32
assert scaled_val.scale.dtype == np.float32
npt.assert_equal(np.asarray(scaled_val), val)
assert isinstance(scaled_val.data, (np.ndarray, np.number))
assert isinstance(scaled_val.scale, (np.ndarray, np.number))

@parameterized.parameters(
{"val": np.float16(0)},
{"val": np.float32(0)},
# {"val": np.float64(0)}, FIXME!
)
def test__make_scaled_scalar__zero_scalar_input(self, val):
scaled_val = make_scaled_scalar(val)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.shape == ()
assert scaled_val.dtype == val.dtype

@parameterized.parameters(
{"val": np.array(1.0)},
{"val": np.float32(-0.5)},
{"val": np.float16(1.25)},
{"val": np.float32(-65504)},
# Testing JAX arrays too.
{"val": jnp.float32(0.5)},
{"val": jnp.float16(1.25)},
)
def test__make_scaled_scalar__proper_split_data_scale(self, val):
scaled_val = make_scaled_scalar(val)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.shape == ()
assert scaled_val.data.dtype == val.dtype
assert scaled_val.scale.dtype == val.dtype
npt.assert_equal(np.asarray(scaled_val), val)
npt.assert_equal(scaled_val.scale, pow2_round_down(val))
npt.assert_array_less(0, scaled_val.scale)
# Make sure we are not doing implicit conversion to JAX arrays.
if isinstance(val, Array):
assert isinstance(scaled_val.data, Array)
assert isinstance(scaled_val.scale, Array)
else:
assert isinstance(scaled_val.data, (np.ndarray, np.number))
assert isinstance(scaled_val.scale, (np.ndarray, np.number))

def test__is_scaled_leaf__consistent_with_jax(self):
assert is_scaled_leaf(8)
assert is_scaled_leaf(2.0)
@@ -134,8 +186,8 @@ def test__as_scaled_array__float_scalar(self, data):
output = as_scaled_array(data)
assert isinstance(output, ScaledArray)
assert output.data.dtype == output.scale.dtype
npt.assert_array_almost_equal(output.data, 1)
npt.assert_array_almost_equal(output.scale, data)
# NOTE: for scalars, data always in [1, 2)
npt.assert_almost_equal(np.asarray(output), data)

@parameterized.parameters(
{"data": jnp.float32(3.0)},