From cd19da5462da6a0e1843d9adfe4d0087c768110b Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 21 Nov 2023 21:56:47 +0000 Subject: [PATCH] Default scaled translation of JAX primitives. Default scaled translation (i.e. unscale/primitive/rescale) for `log`, `exp` and `select`. This will unlock minimal operator coverage for MNIST/MLP experiments. --- jax_scaled_arithmetics/core/__init__.py | 2 +- jax_scaled_arithmetics/core/datatype.py | 16 +++++++ .../lax/base_scaling_primitives.py | 4 +- jax_scaled_arithmetics/lax/scaled_ops.py | 47 +++++++++++++++++++ tests/core/test_datatype.py | 22 ++++++++- tests/lax/test_scaled_ops.py | 29 ++++++++++++ 6 files changed, 116 insertions(+), 4 deletions(-) diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index f56126f..713ea60 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array # noqa: F401 +from .datatype import DTypeLike, ScaledArray, Shape, asarray, is_scaled_leaf, scaled_array # noqa: F401 from .interpreters import ( # noqa: F401 ScaledPrimitiveType, autoscale, diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 08aaa4d..038ef44 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -103,6 +103,22 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa return ScaledArray(data, scale) +def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray: + """Convert back to a common JAX/Numpy array. + + Args: + dtype: Optional dtype of the final array. + """ + if isinstance(val, ScaledArray): + return val.to_array(dtype=dtype) + elif isinstance(val, (jax.Array, np.ndarray)): + if dtype is None: + return val + return val.astype(dtype=dtype) + # Convert to Numpy all other cases? + return np.asarray(val, dtype=dtype) + + def is_scaled_leaf(val: Any) -> bool: """Is input a JAX PyTree (scaled) leaf, including ScaledArray. diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py index f769dba..a778c34 100644 --- a/jax_scaled_arithmetics/lax/base_scaling_primitives.py +++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py @@ -6,7 +6,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import LoweringRuleContext, ir -from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, ScaledPrimitiveType, register_scaled_op +from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, ScaledPrimitiveType, asarray, register_scaled_op set_scaling_p = core.Primitive("set_scaling_p") """`set_scaling` JAX primitive. @@ -43,7 +43,7 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray: """Scaled `set_scaling` implementation: rebalancing the data using the new scale value.""" assert scale.shape == () # Automatic promotion should ensure we always get a scaled scalar here! - scale_value = scale.to_array() + scale_value = asarray(scale) if not isinstance(values, ScaledArray): # Simple case, with no pre-existing scale. return ScaledArray(values / scale_value, scale_value) diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index 1fd6f4d..f96c625 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -2,6 +2,7 @@ from typing import Any, Optional, Sequence, Tuple import jax +import jax.core import jax.numpy as jnp import numpy as np from jax import lax @@ -9,6 +10,14 @@ from jax_scaled_arithmetics import core from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape +from .base_scaling_primitives import scaled_set_scaling + + +@core.register_scaled_lax_op +def scaled_stop_gradient(val: ScaledArray) -> ScaledArray: + # Stop gradients on both data and scale tensors. + return ScaledArray(lax.stop_gradient(val.data), lax.stop_gradient(val.scale)) + @core.register_scaled_lax_op def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray: @@ -195,3 +204,41 @@ def scaled_lt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array: @core.register_scaled_lax_op def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array: return scaled_boolean_binary_op(lhs, rhs, lax.le_p) + + +################################################################## +# Default scaled ops implementation # +################################################################## +def scaled_op_default_translation( + prim: jax.core.Primitive, args: Sequence[ScaledArray], outscale: Optional[jax.Array] = None +) -> ScaledArray: + """Scaled op default translation of a JAX primitive: unscaling inputs + calling normal primitive. + + Args: + prim: JAX primitive + args: Input arguments. + outscale: Output scale to use. + """ + inputs = [core.asarray(v) for v in args] + print(inputs) + output = prim.bind(*inputs) + # Rescale output, if necessary. + if outscale is None: + return ScaledArray(output, np.array(1.0, dtype=output.dtype)) + output_scaled = scaled_set_scaling(output, outscale) + return output_scaled + + +@core.register_scaled_lax_op +def scaled_exp(val: ScaledArray) -> ScaledArray: + return scaled_op_default_translation(lax.exp_p, [val]) + + +@core.register_scaled_lax_op +def scaled_log(val: ScaledArray) -> ScaledArray: + return scaled_op_default_translation(lax.log_p, [val]) + + +@core.register_scaled_lax_op +def scaled_select_n(which: jax.Array, *cases: ScaledArray) -> ScaledArray: + return scaled_op_default_translation(lax.select_n_p, [which, *cases]) diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index dfb01f8..70eafb8 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -6,7 +6,7 @@ from absl.testing import parameterized from jax.core import ShapedArray -from jax_scaled_arithmetics.core import ScaledArray, is_scaled_leaf, scaled_array +from jax_scaled_arithmetics.core import ScaledArray, asarray, is_scaled_leaf, scaled_array class ScaledArrayDataclassTests(chex.TestCase): @@ -82,3 +82,23 @@ def test__is_scaled_leaf__consistent_with_jax(self): assert is_scaled_leaf(np.array([3])) assert is_scaled_leaf(jnp.array([3])) assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16)) + + @parameterized.parameters( + {"data": np.array(2)}, + {"data": np.array([1, 2])}, + {"data": jnp.array([1, 2])}, + ) + def test__asarray__unchanged_dtype(self, data): + output = asarray(data) + assert output.dtype == data.dtype + npt.assert_array_almost_equal(output, data) + + @parameterized.parameters( + {"data": np.array([1, 2])}, + {"data": jnp.array([1, 2])}, + {"data": scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float32)}, + ) + def test__asarray__changed_dtype(self, data): + output = asarray(data, dtype=np.float16) + assert output.dtype == np.float16 + npt.assert_array_almost_equal(output, data) diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py index 191c46f..bfb68e2 100644 --- a/tests/lax/test_scaled_ops.py +++ b/tests/lax/test_scaled_ops.py @@ -13,8 +13,11 @@ scaled_concatenate, scaled_convert_element_type, scaled_dot_general, + scaled_exp, scaled_is_finite, + scaled_log, scaled_mul, + scaled_select_n, scaled_slice, scaled_sub, scaled_transpose, @@ -98,6 +101,22 @@ def test__scaled_dot_general__proper_scaling(self): npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5)) npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs)) + def test__scaled_exp__proper_scaling(self): + val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=np.float32) + out = scaled_exp(val) + assert isinstance(out, ScaledArray) + assert out.dtype == val.dtype + npt.assert_almost_equal(out.scale, 1) # FIXME! + npt.assert_array_almost_equal(out, np.exp(val)) + + def test__scaled_log__proper_scaling(self): + val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=np.float32) + out = scaled_log(val) + assert isinstance(out, ScaledArray) + assert out.dtype == val.dtype + npt.assert_almost_equal(out.scale, 1) # FIXME! + npt.assert_array_almost_equal(out, np.log(val)) + class ScaledTranslationReducePrimitivesTests(chex.TestCase): def setUp(self): @@ -161,3 +180,13 @@ def test__scaled_boolean_binary_op__proper_result(self, bool_prim): assert out0.dtype == np.bool_ npt.assert_array_equal(out0, bool_prim.bind(lhs.to_array(), rhs.to_array())) npt.assert_array_equal(out1, bool_prim.bind(lhs.to_array(), lhs.to_array())) + + def test__scaled_select_n__proper_result(self): + mask = self.rs.rand(5) > 0.5 + lhs = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32) + rhs = scaled_array(self.rs.rand(5), 3.0, dtype=np.float32) + out = scaled_select_n(mask, lhs, rhs) + assert isinstance(out, ScaledArray) + assert out.dtype == np.float32 + npt.assert_almost_equal(out.scale, 1) # FIXME! + npt.assert_array_equal(out, np.where(mask, rhs, lhs))