diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 91ef749..6e46c13 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -113,7 +113,9 @@ def to_scaled_array(val):
 
         # Primitive is supported by `autoscale`?
         if eqn.primitive not in _scaled_ops_registry:
-            raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
+            raise NotImplementedError(
+                f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
+            )
         outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
         if not eqn.primitive.multiple_results:
             outvals = [outvals]
diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py
index 65b52cc..4d7e612 100644
--- a/jax_scaled_arithmetics/lax/__init__.py
+++ b/jax_scaled_arithmetics/lax/__init__.py
@@ -1,2 +1,3 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from .base_scaling_primitives import set_scaling, set_scaling_p  # noqa: F401
 from .scaled_ops import *  # noqa: F401, F403
diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
new file mode 100644
index 0000000..5d29715
--- /dev/null
+++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from typing import Sequence, Union
+
+import jax
+from jax import core
+from jax.interpreters import mlir
+from jax.interpreters.mlir import LoweringRuleContext, ir
+
+from jax_scaled_arithmetics.core import ScaledArray, register_scaled_op
+
+set_scaling_p = core.Primitive("set_scaling_p")
+"""`set_scaling` JAX primitive.
+
+In standard JAX, this is an identity operation, ignoring the `scale`
+input and returning the `data` component unchanged.
+
+In JAX Scaled Arithmetics mode, it rebalances the `data` term to
+return a semantically equivalent ScaledArray.
+"""
+
+
+def set_scaling(data: jax.Array, scale: jax.Array) -> jax.Array:
+    return set_scaling_p.bind(data, scale)
+
+
+def set_scaling_impl(data: jax.Array, scale: jax.Array) -> jax.Array:
+    return data
+
+
+def set_scaling_abstract_eval(data: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray:
+    return data
+
+
+def set_default_mlir_lowering(
+    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]]
+) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
+    # Just forwarding the `data` term, ignoring the scaling.
+    return (args[0],)
+
+
+def scaled_set_scaling(val: ScaledArray, scale: ScaledArray) -> ScaledArray:
+    """Scaled `set_scaling` implementation: rebalance the data using the new scale value."""
+    assert isinstance(val, ScaledArray)
+    assert isinstance(scale, ScaledArray)
+    assert scale.shape == ()
+
+    scale_value = scale.to_array()
+    # Rebalance the data term, keeping `data * scale` invariant.
+    data = val.data * (val.scale / scale_value)
+    return ScaledArray(data, scale_value)
+
+
+set_scaling_p.multiple_results = False
+# Register the abstract evaluation with JAX.
+set_scaling_p.def_abstract_eval(set_scaling_abstract_eval)
+# Register the primal implementation with JAX.
+set_scaling_p.def_impl(set_scaling_impl)
+# Register the (default) MLIR lowering.
+mlir.register_lowering(set_scaling_p, set_default_mlir_lowering)
+register_scaled_op(set_scaling_p, scaled_set_scaling)
+
+
+stop_scaling_p = core.Primitive("stop_scaling_p")
+"""`stop_scaling` JAX primitive.
+
+Similar in principle to `jax.lax.stop_gradient`.
+"""
diff --git a/tests/lax/test_base_scaling_primitives.py b/tests/lax/test_base_scaling_primitives.py
new file mode 100644
index 0000000..7b7e697
--- /dev/null
+++ b/tests/lax/test_base_scaling_primitives.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import chex
+import jax.numpy as jnp
+import numpy as np
+import numpy.testing as npt
+
+from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
+from jax_scaled_arithmetics.lax import set_scaling
+
+
+class SetScalingPrimitiveTests(chex.TestCase):
+    @chex.variants(with_jit=True, without_jit=True)
+    def test__set_scaling_primitive__proper_result_without_autoscale(self):
+        def fn(arr, scale):
+            return set_scaling(arr, scale)
+
+        fn = self.variant(fn)
+        arr = jnp.array([2, 3], dtype=np.float32)
+        scale = jnp.array(4, dtype=np.float32)
+        out = fn(arr, scale)
+        npt.assert_array_equal(out, arr)
+
+    @chex.variants(with_jit=True, without_jit=False)
+    def test__set_scaling_primitive__proper_result_with_autoscale(self):
+        def fn(arr, scale):
+            return set_scaling(arr, scale)
+
+        fn = self.variant(autoscale(fn))
+        arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32)
+        # TODO: support scalar scale input here!
+        scale = scaled_array(1.0, 4.0, dtype=np.float32)
+        out = fn(arr, scale)
+        # Semantically unchanged output tensor, with the new scale.
+        assert isinstance(out, ScaledArray)
+        npt.assert_array_equal(out.scale, scale)
+        npt.assert_array_equal(out, arr)
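Note on intended semantics: a minimal sketch of `set_scaling` under `autoscale`, mirroring the new unit test (which runs the `autoscale` path jitted) and using only helpers this diff imports or defines:

    import numpy as np
    from jax_scaled_arithmetics.core import autoscale, scaled_array
    from jax_scaled_arithmetics.lax import set_scaling

    # ScaledArray(data=[-1, 2], scale=1) represents the tensor [-1, 2].
    arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32)
    # ScaledArray(data=1, scale=4) represents the scalar 4.
    scale = scaled_array(1.0, 4.0, dtype=np.float32)

    out = autoscale(set_scaling)(arr, scale)
    # `scaled_set_scaling` keeps `data * scale` invariant while adopting the
    # requested scale: out.data == [-0.25, 0.5] and out.scale == 4, so the
    # represented tensor is still [-1, 2].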
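Review note: `stop_scaling_p` is declared at the end of the new module but not yet given an implementation or registrations. A hypothetical completion, mirroring the `set_scaling` boilerplate above (not part of this diff; the scaled rule assumes `ScaledArray.to_array()` materializes the full tensor, as already used in `scaled_set_scaling`):

    def stop_scaling(data: jax.Array) -> jax.Array:
        return stop_scaling_p.bind(data)


    def stop_scaling_impl(data: jax.Array) -> jax.Array:
        return data


    stop_scaling_p.multiple_results = False
    stop_scaling_p.def_abstract_eval(lambda data: data)
    stop_scaling_p.def_impl(stop_scaling_impl)
    # Identity lowering, same pattern as `set_default_mlir_lowering`.
    mlir.register_lowering(stop_scaling_p, lambda ctx, *args: (args[0],))
    # Hypothetical scaled rule: exit scaled mode by materializing the array.
    register_scaled_op(stop_scaling_p, lambda val: val.to_array())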