Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve make_scaled_scalar with optional scale dtype parameter. #93

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def astype(self, dtype) -> "ScaledArray":
return ScaledArray(self.data.astype(dtype), self.scale)


def make_scaled_scalar(val: Array) -> ScaledArray:
def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
"""Make a scaled scalar (array), from a single value.

The returned scalar will always be built such that:
Expand All @@ -118,8 +118,11 @@ def make_scaled_scalar(val: Array) -> ScaledArray:
val = np.float32(val)
assert np.ndim(val) == 0
assert np.issubdtype(val.dtype, np.floating)
# Scale dtype to use.
# TODO: check the scale dtype?
scale_dtype = scale_dtype or val.dtype
# Split mantissa and exponent in data and scale components.
scale = pow2_round_down(val)
scale = pow2_round_down(val.astype(scale_dtype))
npapi = get_numpy_api(scale)
return ScaledArray(npapi.asarray(get_mantissa(val)), scale)

Expand Down Expand Up @@ -155,8 +158,16 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
return scaled_array_base(data, scale, dtype, npapi)


def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[Array, ScaledArray]:
"""ScaledArray (helper) base factory method, similar to `(j)np.array`."""
def as_scaled_array_base(
val: Any, scale: Optional[ArrayLike] = None, scale_dtype: Optional[DTypeLike] = None
) -> Union[Array, ScaledArray]:
"""ScaledArray (helper) base factory method, similar to `(j)np.array`.

Args:
val: Value to convert to scaled array.
scale: Optional scale value.
scale_dtype: Optional (default) scale dtype.
"""
if isinstance(val, ScaledArray):
return val

Expand All @@ -166,17 +177,18 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[A
if is_static_one_scale and isinstance(val, (bool, int)):
return val
if is_static_one_scale and isinstance(val, float):
return make_scaled_scalar(np.float32(val))
return make_scaled_scalar(np.float32(val), scale_dtype)

# Ignored dtypes by default: int and bool
ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_)
if ignored_dtype:
return val
# Floating point scalar
if val.ndim == 0 and is_static_one_scale:
return make_scaled_scalar(val)
return make_scaled_scalar(val, scale_dtype)

scale = np.array(1, dtype=val.dtype) if scale is None else scale
scale_dtype = scale_dtype or val.dtype
scale = np.array(1, dtype=scale_dtype) if scale is None else scale
if isinstance(val, (np.ndarray, Array)):
if is_static_one_scale:
return ScaledArray(val, scale)
Expand Down
27 changes: 9 additions & 18 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import numpy as np
Expand All @@ -15,7 +15,7 @@
)
from jax._src.util import safe_map

from .datatype import Array, DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


Expand Down Expand Up @@ -96,24 +96,13 @@ def _get_data(val: Any) -> Array:
return val


def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
def promote_scalar_to_scaled_array(val: Any, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
"""Promote a scalar (Numpy, JAX, ...) to a Scaled Array.

Note: needs to work with any input type, including JAX tracer ones.
"""
# Use `as_scaled_array` promotion rules.
return as_scaled_array_base(val)


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
"""Get the ScaledArray corresponding to a Numpy constant.

Only supporting Numpy scalars at the moment.
"""
# TODO: generalized rules!
assert np.ndim(val) == 0
assert np.issubdtype(val.dtype, np.floating)
return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
return as_scaled_array_base(val, scale_dtype=scale_dtype)


def register_scaled_op(
Expand Down Expand Up @@ -200,6 +189,8 @@ def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}
# Check dtype consistency between normal and scaled modes.
safe_check_dtypes: bool = False
# AutoScale config to use.
autoscale_cfg = get_autoscale_config()

def read(var):
if type(var) is core.Literal:
Expand All @@ -209,11 +200,11 @@ def read(var):
def write(var, val):
env[var] = val

def promote_to_scaled_array(val):
def promote_to_scaled_array(val, scale_dtype):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val)
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val

Expand Down Expand Up @@ -245,7 +236,7 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se
)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
scaled_invals = list(map(promote_to_scaled_array, invals))
scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals))
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
Expand Down
20 changes: 20 additions & 0 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,26 @@ def test__make_scaled_scalar__zero_scalar_input(self, val):
assert scaled_val.shape == ()
assert scaled_val.dtype == val.dtype

def test__make_scaled_scalar__optional_scale_dtype(self):
val = np.float16(0.25)
scaled_val = make_scaled_scalar(val, scale_dtype=np.float32)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.dtype == val.dtype
assert scaled_val.scale.dtype == np.float32
npt.assert_equal(np.asarray(scaled_val), val)

@parameterized.parameters(
{"val": np.finfo(np.float16).smallest_normal},
{"val": np.finfo(np.float16).smallest_subnormal},
{"val": np.float16(3.123283386230469e-05)},
)
def test__make_scaled_scalar__fp16_subnormal_support(self, val):
# Use FP32 scale dtype, to have enough range.
# NOTE: failing in FP16!
scaled_val = make_scaled_scalar(val, scale_dtype=np.float32)
# No loss of information when converting everything to FP32.
npt.assert_equal(np.asarray(scaled_val, dtype=np.float32), np.float32(val))

@parameterized.parameters(
{"val": np.array(1.0)},
{"val": np.float32(-0.5)},
Expand Down
17 changes: 17 additions & 0 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,20 @@ def test__autoscale_config__context_manager(self):
assert isinstance(cfg, AutoScaleConfig)
assert cfg.rounding_mode == Pow2RoundMode.NONE
assert cfg.scale_dtype == np.float32

def test__autoscale_config__scale_dtype_used_in_interpreter_promotion(self):
def fn(x):
# Underflowing to zero in `autoscale` mode if scale_dtype == np.float16.
return x * 3.123283386230469e-05

scaled_input = scaled_array(np.array(2.0, np.float16), scale=np.float32(0.5))
expected_output = fn(np.float16(1))

with AutoScaleConfig(scale_dtype=np.float32):
scaled_output = autoscale(fn)(scaled_input)
assert scaled_output.scale.dtype == np.float32
npt.assert_equal(np.asarray(scaled_output, dtype=np.float32), expected_output)

with AutoScaleConfig(scale_dtype=np.float16):
scaled_output = autoscale(fn)(scaled_input)
npt.assert_almost_equal(scaled_output.scale, 0)
Loading