diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index cab8d14..9b00264 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -15,7 +15,7 @@ ) from jax._src.util import safe_map -from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf +from .datatype import DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf from .utils import Pow2RoundMode @@ -25,9 +25,14 @@ class AutoScaleConfig: NOTE: this config can be locally changed using a Python context manager: `with AutoScaleConfig(...):` + + Args: + rounding_mode: Power-of-2 rounding mode. + scale_dtype: Scale (default) datatype. """ rounding_mode: Pow2RoundMode = Pow2RoundMode.DOWN + scale_dtype: DTypeLike = None def __enter__(self): global _autoscale_config_stack diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py index 8aa2466..81878f8 100644 --- a/jax_scaled_arithmetics/lax/base_scaling_primitives.py +++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import logging from typing import Optional, Sequence, Union import numpy as np @@ -12,6 +13,7 @@ ScaledArray, ScaledPrimitiveType, asarray, + get_autoscale_config, is_static_one_scalar, register_scaled_op, safe_div, @@ -163,6 +165,11 @@ def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) """ +def get_scale_dtype() -> Optional[DTypeLike]: + """Get the scale dtype, if set in the AutoScale config.""" + return get_autoscale_config().scale_dtype + + def get_data_scale(values: Array) -> Array: """`get_data_scale` primitive call method.""" return get_data_scale_p.bind(values) @@ -171,14 +178,19 @@ def get_data_scale(values: Array) -> Array: def get_data_scale_impl(values: Array) -> Array: if isinstance(values, ScaledArray): return (values.data, values.scale) - scale = np.ones((), dtype=values.dtype) + # Use array dtype for scale by default. + scale_dtype = get_scale_dtype() or values.dtype + scale = np.ones((), dtype=scale_dtype) return values, scale def get_data_scale_abstract_eval(values: core.ShapedArray) -> core.ShapedArray: if isinstance(values, ScaledArray): return (values.data, values.scale) - return values, core.ShapedArray((), dtype=values.dtype) + # Use array dtype for scale by default. + scale_dtype = get_scale_dtype() or values.dtype + print(scale_dtype) + return values, core.ShapedArray((), dtype=scale_dtype) def get_data_scale_mlir_lowering( @@ -186,12 +198,22 @@ def get_data_scale_mlir_lowering( ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: # Just forwarding `values` term, adding a constant scalar scale(1). assert len(args) == 1 - scale = ir_constant(np.ones((), dtype=ctx.avals_in[0].dtype)) + assert len(ctx.avals_in) == 1 + assert len(ctx.avals_out) == 2 + # Scale dtype "decided" during initial JAX tracing. + scale_dtype = ctx.avals_out[1].dtype + scale = ir_constant(np.ones((), dtype=scale_dtype)) return (args[0], scale) def scaled_get_data_scale(values: ScaledArray) -> Array: """Scaled `get_data_scale` implementation: return scale tensor.""" + scale_dtype = get_scale_dtype() + # Mis-match may potentially create issues (i.e. not equivalent scale dtype after autoscale tracer)! + if scale_dtype != values.scale.dtype: + logging.warning( + f"Autoscale config scale dtype not matching ScaledArray scale dtype: '{values.scale.dtype}' vs '{scale_dtype}'. AutoScale graph transformation may fail because of that." + ) return values.data, values.scale diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index f0665f6..808fa19 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -237,9 +237,11 @@ def test__autoscale_config__default_values(self): cfg = get_autoscale_config() assert isinstance(cfg, AutoScaleConfig) assert cfg.rounding_mode == Pow2RoundMode.DOWN + assert cfg.scale_dtype is None def test__autoscale_config__context_manager(self): - with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE): + with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE, scale_dtype=np.float32): cfg = get_autoscale_config() assert isinstance(cfg, AutoScaleConfig) assert cfg.rounding_mode == Pow2RoundMode.NONE + assert cfg.scale_dtype == np.float32 diff --git a/tests/lax/test_base_scaling_primitives.py b/tests/lax/test_base_scaling_primitives.py index af8e401..8a017f4 100644 --- a/tests/lax/test_base_scaling_primitives.py +++ b/tests/lax/test_base_scaling_primitives.py @@ -5,7 +5,7 @@ import numpy.testing as npt from absl.testing import parameterized -from jax_scaled_arithmetics.core import Array, ScaledArray, autoscale, scaled_array +from jax_scaled_arithmetics.core import Array, AutoScaleConfig, ScaledArray, autoscale, scaled_array from jax_scaled_arithmetics.lax.base_scaling_primitives import ( get_data_scale, rebalance, @@ -146,13 +146,17 @@ class GetDataScalePrimitiveTests(chex.TestCase): @chex.variants(with_jit=True, without_jit=True) def test__get_data_scale_primitive__proper_result_without_autoscale(self): def fn(arr): - return get_data_scale(arr) + # Set a default scale dtype. + with AutoScaleConfig(scale_dtype=np.float32): + return get_data_scale(arr) fn = self.variant(fn) arr = jnp.array([2, 3], dtype=np.float16) data, scale = fn(arr) + assert data.dtype == np.float16 + assert scale.dtype == np.float32 npt.assert_array_equal(data, arr) - npt.assert_equal(scale, np.array(1, arr.dtype)) + npt.assert_equal(scale, np.array(1, np.float32)) @chex.variants(with_jit=True, without_jit=True) def test__get_data_scale_primitive__proper_result_with_autoscale(self):