diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index 4d66c5b..f56126f 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,9 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array  # noqa: F401
-from .interpreters import ScaledPrimitiveType, autoscale, register_scaled_lax_op, register_scaled_op  # noqa: F401
+from .interpreters import (  # noqa: F401
+    ScaledPrimitiveType,
+    autoscale,
+    find_registered_scaled_op,
+    register_scaled_lax_op,
+    register_scaled_op,
+)
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index f157289..7d1ffb0 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from enum import IntEnum
 from functools import wraps
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import jax
 import numpy as np
@@ -10,7 +10,24 @@
 
 from .datatype import NDArray, ScaledArray
 
-_scaled_ops_registry: Dict[core.Primitive, Any] = {}
+
+class ScaledPrimitiveType(IntEnum):
+    """Scale (JAX) primitive type.
+
+    This enum describes the behaviour when `autoscale` is
+    tracing the graph.
+
+    NEVER: Value returned by `find_registered_scaled_op` for unregistered primitives.
+    FORWARD: Forwarding scaling => only used if scaled inputs. Default behaviour.
+    ALWAYS_SCALE: Always use scaled version.
+    """
+
+    NEVER = 0
+    FORWARD = 1
+    ALWAYS_SCALE = 2
+
+
+_scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {}
 
 
 def _get_lax_prim(scaled_func: Any) -> core.Primitive:
@@ -43,22 +60,6 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
     return ScaledArray(data=onedata, scale=val)
 
 
-class ScaledPrimitiveType(IntEnum):
-    """Scale (JAX) primitive type.
-
-    This enum described the behaviour when `autoscale` is
-    tracing the graph.
-
-    FORWARD: Forwarding scaling => only used if scaled inputs.
-        Default behaviour.
-    ALWAYS_SCALE: Always use scaled version.
-    """
-
-    NEVER = 0
-    FORWARD = 1
-    ALWAYS_SCALE = 2
-
-
 def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
     """Get the ScaledArray corresponding to a Numpy constant.
 
@@ -102,6 +103,16 @@ def register_scaled_lax_op(scaled_func):
     return scaled_func
 
 
+def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiveType]:
+    """Find a registered JAX scaled operation/translation. Returns
+    `(None, ScaledPrimitiveType.NEVER)` if the primitive has no scaled translation registered.
+
+    Args:
+        prim: JAX primitive.
+    """
+    return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))
+
+
 def autoscale(fun):
     """`autoscale` JAX graph transformation.
 
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index f54905d..5bbd252 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -107,3 +107,46 @@ def scaled_dot_general(
         / contracting_rescale
     )
     return ScaledArray(output_data, output_scale)
+
+
+@core.register_scaled_lax_op
+def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
+    """Scaled `reduce_sum`: data is divided by sqrt(reduced size), scale is multiplied by it."""
+    assert isinstance(val, ScaledArray)
+    shape = val.shape
+    axes_size = np.array([shape[idx] for idx in axes])
+    # Rescale data component following reduction axes.
+    axes_rescale = np.sqrt(np.prod(axes_size))
+    data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale
+    outscale = val.scale * axes_rescale
+    return ScaledArray(data, outscale)
+
+
+@core.register_scaled_lax_op
+def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
+    """Scaled `reduce_prod`: data is reduced as-is, scale is raised to the reduced size."""
+    assert isinstance(val, ScaledArray)
+    shape = val.shape
+    data = lax.reduce_prod_p.bind(val.data, axes=axes)
+    # `lax.integer_pow` needs a Python int exponent; `np.prod` of empty axes is the float 1.0.
+    axes_size = int(np.prod(np.array([shape[idx] for idx in axes])))
+    scale = lax.integer_pow(val.scale, axes_size)
+    return ScaledArray(data, scale)
+
+
+@core.register_scaled_lax_op
+def scaled_reduce_max(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
+    """Scaled `reduce_max`: data is reduced, scale is passed through."""
+    assert isinstance(val, ScaledArray)
+    data = lax.reduce_max_p.bind(val.data, axes=axes)
+    # unchanged scaling.
+    return ScaledArray(data, val.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_reduce_min(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
+    """Scaled `reduce_min`: data is reduced, scale is passed through."""
+    assert isinstance(val, ScaledArray)
+    data = lax.reduce_min_p.bind(val.data, axes=axes)
+    # unchanged scaling.
+    return ScaledArray(data, val.scale)
diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py
index f733f05..799f545 100644
--- a/tests/lax/test_scaled_ops.py
+++ b/tests/lax/test_scaled_ops.py
@@ -2,8 +2,10 @@
 import chex
 import numpy as np
 import numpy.testing as npt
+from absl.testing import parameterized
+from jax import lax
 
-from jax_scaled_arithmetics.core import ScaledArray, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, find_registered_scaled_op, scaled_array
 from jax_scaled_arithmetics.lax import (
     scaled_add,
     scaled_broadcast_in_dim,
@@ -93,3 +95,28 @@ def test__scaled_dot_general__proper_scaling(self):
         assert out.dtype == lhs.dtype
         npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
         npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs))
+
+
+class ScaledTranslationReducePrimitivesTests(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        # Use random state for reproducibility!
+        self.rs = np.random.RandomState(42)
+
+    @parameterized.parameters(
+        {"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * np.sqrt(5)},
+        {"reduce_prim": lax.reduce_prod_p, "expected_scale": 2**5},
+        {"reduce_prim": lax.reduce_min_p, "expected_scale": 2},
+        {"reduce_prim": lax.reduce_max_p, "expected_scale": 2},
+    )
+    def test__scaled_reduce__single_axis__proper_scaling(self, reduce_prim, expected_scale):
+        axes = (0,)
+        val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
+        scaled_reduce_op, _ = find_registered_scaled_op(reduce_prim)
+        out = scaled_reduce_op(val, axes=axes)
+
+        assert isinstance(out, ScaledArray)
+        assert out.shape == ()
+        assert out.dtype == val.dtype
+        npt.assert_almost_equal(out.scale, expected_scale)
+        npt.assert_array_almost_equal(out, reduce_prim.bind(np.asarray(val), axes=axes))