diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index bc08040..a3d3e71 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -101,7 +101,7 @@ def astype(self, dtype) -> "ScaledArray": return ScaledArray(self.data.astype(dtype), self.scale) -def make_scaled_scalar(val: Array) -> ScaledArray: +def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray: """Make a scaled scalar (array), from a single value. The returned scalar will always be built such that: @@ -118,8 +118,11 @@ def make_scaled_scalar(val: Array) -> ScaledArray: val = np.float32(val) assert np.ndim(val) == 0 assert np.issubdtype(val.dtype, np.floating) + # Scale dtype to use. + # TODO: check the scale dtype? + scale_dtype = scale_dtype or val.dtype # Split mantissa and exponent in data and scale components. - scale = pow2_round_down(val) + scale = pow2_round_down(val.astype(scale_dtype)) npapi = get_numpy_api(scale) return ScaledArray(npapi.asarray(get_mantissa(val)), scale) @@ -155,8 +158,16 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa return scaled_array_base(data, scale, dtype, npapi) -def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[Array, ScaledArray]: - """ScaledArray (helper) base factory method, similar to `(j)np.array`.""" +def as_scaled_array_base( + val: Any, scale: Optional[ArrayLike] = None, scale_dtype: Optional[DTypeLike] = None +) -> Union[Array, ScaledArray]: + """ScaledArray (helper) base factory method, similar to `(j)np.array`. + + Args: + val: Value to convert to scaled array. + scale: Optional scale value. + scale_dtype: Optional (default) scale dtype. + """ if isinstance(val, ScaledArray): return val @@ -166,7 +177,7 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[A if is_static_one_scale and isinstance(val, (bool, int)): return val if is_static_one_scale and isinstance(val, float): - return make_scaled_scalar(np.float32(val)) + return make_scaled_scalar(np.float32(val), scale_dtype) # Ignored dtypes by default: int and bool ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_) @@ -174,9 +185,10 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[A return val # Floating point scalar if val.ndim == 0 and is_static_one_scale: - return make_scaled_scalar(val) + return make_scaled_scalar(val, scale_dtype) - scale = np.array(1, dtype=val.dtype) if scale is None else scale + scale_dtype = scale_dtype or val.dtype + scale = np.array(1, dtype=scale_dtype) if scale is None else scale if isinstance(val, (np.ndarray, Array)): if is_static_one_scale: return ScaledArray(val, scale) diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 45791f5..9add00a 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import IntEnum from functools import partial, wraps -from typing import Any, Dict, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import jax import numpy as np @@ -15,7 +15,7 @@ ) from jax._src.util import safe_map -from .datatype import Array, DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf +from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf from .utils import Pow2RoundMode @@ -96,24 +96,13 @@ def _get_data(val: Any) -> Array: return val -def promote_scalar_to_scaled_array(val: Any) -> ScaledArray: +def promote_scalar_to_scaled_array(val: Any, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray: """Promote a scalar (Numpy, JAX, ...) to a Scaled Array. Note: needs to work with any input type, including JAX tracer ones. """ # Use `as_scaled_array` promotion rules. - return as_scaled_array_base(val) - - -def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray: - """Get the ScaledArray corresponding to a Numpy constant. - - Only supporting Numpy scalars at the moment. - """ - # TODO: generalized rules! - assert np.ndim(val) == 0 - assert np.issubdtype(val.dtype, np.floating) - return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val)) + return as_scaled_array_base(val, scale_dtype=scale_dtype) def register_scaled_op( @@ -200,6 +189,8 @@ def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args): env: Dict[core.Var, ScaledArray] = {} # Check dtype consistency between normal and scaled modes. safe_check_dtypes: bool = False + # AutoScale config to use. + autoscale_cfg = get_autoscale_config() def read(var): if type(var) is core.Literal: @@ -209,11 +200,11 @@ def read(var): def write(var, val): env[var] = val - def promote_to_scaled_array(val): + def promote_to_scaled_array(val, scale_dtype): if isinstance(val, ScaledArray): return val elif np.ndim(val) == 0: - return promote_scalar_to_scaled_array(val) + return promote_scalar_to_scaled_array(val, scale_dtype) # No promotion rule => just return as such. return val @@ -245,7 +236,7 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se ) else: # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible. - scaled_invals = list(map(promote_to_scaled_array, invals)) + scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals)) outvals = scaled_prim_fn(*scaled_invals, **eqn.params) if not eqn.primitive.multiple_results: outvals = [outvals] diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index 3aef8c4..40eddff 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -137,6 +137,26 @@ def test__make_scaled_scalar__zero_scalar_input(self, val): assert scaled_val.shape == () assert scaled_val.dtype == val.dtype + def test__make_scaled_scalar__optional_scale_dtype(self): + val = np.float16(0.25) + scaled_val = make_scaled_scalar(val, scale_dtype=np.float32) + assert isinstance(scaled_val, ScaledArray) + assert scaled_val.dtype == val.dtype + assert scaled_val.scale.dtype == np.float32 + npt.assert_equal(np.asarray(scaled_val), val) + + @parameterized.parameters( + {"val": np.finfo(np.float16).smallest_normal}, + {"val": np.finfo(np.float16).smallest_subnormal}, + {"val": np.float16(3.123283386230469e-05)}, + ) + def test__make_scaled_scalar__fp16_subnormal_support(self, val): + # Use FP32 scale dtype, to have enough range. + # NOTE: failing in FP16! + scaled_val = make_scaled_scalar(val, scale_dtype=np.float32) + # No loss of information when converting everything to FP32. + npt.assert_equal(np.asarray(scaled_val, dtype=np.float32), np.float32(val)) + @parameterized.parameters( {"val": np.array(1.0)}, {"val": np.float32(-0.5)}, diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index e68fac9..abebc69 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -247,3 +247,20 @@ def test__autoscale_config__context_manager(self): assert isinstance(cfg, AutoScaleConfig) assert cfg.rounding_mode == Pow2RoundMode.NONE assert cfg.scale_dtype == np.float32 + + def test__autoscale_config__scale_dtype_used_in_interpreter_promotion(self): + def fn(x): + # Underflowing to zero in `autoscale` mode if scale_dtype == np.float16. + return x * 3.123283386230469e-05 + + scaled_input = scaled_array(np.array(2.0, np.float16), scale=np.float32(0.5)) + expected_output = fn(np.float16(1)) + + with AutoScaleConfig(scale_dtype=np.float32): + scaled_output = autoscale(fn)(scaled_input) + assert scaled_output.scale.dtype == np.float32 + npt.assert_equal(np.asarray(scaled_output, dtype=np.float32), expected_output) + + with AutoScaleConfig(scale_dtype=np.float16): + scaled_output = autoscale(fn)(scaled_input) + npt.assert_almost_equal(scaled_output.scale, 0)