diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index 37890b6..31e9aec 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from .datatype import DTypeLike, ScaledArray, Shape, scaled_array  # noqa: F401
+from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array  # noqa: F401
 from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py
index eb656df..09cbf30 100644
--- a/jax_scaled_arithmetics/core/datatype.py
+++ b/jax_scaled_arithmetics/core/datatype.py
@@ -101,3 +101,13 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
     data = npapi.asarray(data, dtype=dtype)
     scale = npapi.asarray(scale)
     return ScaledArray(data, scale)
+
+
+def is_scaled_leaf(val: Any) -> bool:
+    """Is the input a JAX PyTree (scaled) leaf, including ScaledArray?
+
+    This function is useful for JAX PyTree handling when the user wants
+    to keep ScaledArray data structures as leaves (i.e. not flattened into a pair of arrays).
+    """
+    # TODO: check scalars as well?
+    return isinstance(val, (jax.Array, np.ndarray, ScaledArray, int, float))
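Note on `is_scaled_leaf`: the predicate is meant to be passed as the `is_leaf` argument of the `jax.tree_util` functions, so that `ScaledArray` (itself a registered PyTree node, flattened by JAX into two variables per array) is kept as a single leaf. A minimal usage sketch, with hypothetical values:

```python
import jax
import numpy as np

from jax_scaled_arithmetics.core import is_scaled_leaf, scaled_array

params = {"b": np.zeros(2, np.float32), "w": scaled_array([1.0, 2.0], 3.0, dtype=np.float32)}
# Without `is_leaf`, the "w" entry would be flattened into its (data, scale) components.
leaves, treedef = jax.tree_util.tree_flatten(params, is_leaf=is_scaled_leaf)
assert len(leaves) == 2  # the "b" array + the "w" ScaledArray, kept whole
```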
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 9810535..91ef749 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -4,14 +4,26 @@ from typing import Any, Dict
 
 import jax
+import numpy as np
 from jax import core
 from jax._src.util import safe_map
 
-from ..core import ScaledArray
+from .datatype import NDArray, ScaledArray
 
 _scaled_ops_registry: Dict[core.Primitive, Any] = {}
 
 
+def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
+    """Get the ScaledArray corresponding to a NumPy constant.
+
+    Only NumPy scalars are supported at the moment.
+    """
+    # TODO: generalized rules!
+    assert val.shape == ()
+    assert np.issubdtype(val.dtype, np.floating)
+    return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
+
+
 def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
     """Register the scaled translation of JAX primitive.
@@ -57,11 +69,16 @@ def autoscale(fun):
 
     @wraps(fun)
     def wrapped(*args, **kwargs):
         aval_args = safe_map(lambda x: x.aval, args)
-        # get jaxpr of unscaled graph
-        closed_jaxpr = jax.make_jaxpr(fun)(*aval_args, **kwargs)
-        # convert to scaled graph
-        out = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
-        return out
+        # Get the jaxpr of the unscaled/normal graph, as well as the output PyTree shape.
+        closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
+        out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
+        # Trace the graph & convert it to a scaled one.
+        outputs_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
+        # Reconstruct the output PyTree, with scaled arrays.
+        # NOTE: this step also handles single vs. multiple outputs.
+        assert len(out_leaves) == len(outputs_flat)
+        output = jax.tree_util.tree_unflatten(out_pytree, outputs_flat)
+        return output
 
     return wrapped
@@ -77,11 +94,24 @@ def read(var):
 
     def write(var, val):
         env[var] = val
 
+    def to_scaled_array(val):
+        if isinstance(val, ScaledArray):
+            return val
+        elif isinstance(val, np.ndarray):
+            return numpy_constant_scaled_array(val)
+        raise TypeError(f"Cannot convert '{val}' to a scaled array.")
+
     safe_map(write, jaxpr.invars, args)
     safe_map(write, jaxpr.constvars, consts)
 
     for eqn in jaxpr.eqns:
         invals = safe_map(read, eqn.invars)
+        # Make sure all inputs are scaled arrays.
+        invals = list(map(to_scaled_array, invals))
+        assert all(isinstance(v, ScaledArray) for v in invals)
+        # TODO: handle the `stop_scale` case? Integer/boolean dtypes?
+
+        # Is the primitive supported by `autoscale`?
         if eqn.primitive not in _scaled_ops_registry:
             raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
         outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
@@ -90,7 +120,4 @@ def write(var, val):
         safe_map(write, eqn.outvars, outvals)
 
     outvals = safe_map(read, jaxpr.outvars)
-    if len(outvals) == 1:
-        return outvals[0]
-    else:
-        return outvals
+    return outvals
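Note on the constant-promotion rule: `numpy_constant_scaled_array` encodes a NumPy scalar `c` as `ScaledArray(data=1.0, scale=c)`, so the constant's magnitude is carried entirely by the scale; `to_scaled_array` then applies it whenever a jaxpr input is a plain NumPy array. A minimal sketch of the invariant, with hypothetical values:

```python
import numpy as np

from jax_scaled_arithmetics.core.interpreters import numpy_constant_scaled_array

c = np.array(2.0, dtype=np.float32)  # scalar, floating dtype: the only supported case so far
sc = numpy_constant_scaled_array(c)
# `data * scale` recovers the original constant: 1.0 * 2.0 == 2.0.
np.testing.assert_array_equal(np.asarray(sc.data) * np.asarray(sc.scale), c)
```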
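Note on `return_shape=True`: `jax.make_jaxpr(fun, return_shape=True)` also returns a PyTree of `jax.ShapeDtypeStruct` mirroring the function's outputs, and its treedef is what lets `wrapped` rebuild arbitrary output PyTrees (dicts, tuples, single arrays) from the flat list produced by `autoscale_jaxpr`. A standalone sketch of that round trip, using plain arrays and `jax.core.eval_jaxpr` instead of the scaled interpreter:

```python
import jax
import numpy as np


def fn(x):
    return {"y": (x * 2,)}  # non-trivial output PyTree


x = np.ones(3, dtype=np.float32)
closed_jaxpr, outshape = jax.make_jaxpr(fn, return_shape=True)(x)
out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
# Evaluating the jaxpr yields flat outputs; unflattening restores {"y": (...,)}.
flat_outputs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, x)
assert len(out_leaves) == len(flat_outputs)
output = jax.tree_util.tree_unflatten(out_pytree, flat_outputs)
assert isinstance(output["y"], tuple)
```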
diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py
index f0e0ad0..dfb01f8 100644
--- a/tests/core/test_datatype.py
+++ b/tests/core/test_datatype.py
@@ -6,7 +6,7 @@
 from absl.testing import parameterized
 from jax.core import ShapedArray
 
-from jax_scaled_arithmetics import ScaledArray, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, is_scaled_leaf, scaled_array
 
 
 class ScaledArrayDataclassTests(chex.TestCase):
@@ -74,3 +74,11 @@ def test__scaled_array__numpy_array_interface(self, npapi):
         out = np.asarray(sarr)
         assert isinstance(out, np.ndarray)
         npt.assert_array_equal(out, sarr.data * sarr.scale)
+
+    def test__is_scaled_leaf__consistent_with_jax(self):
+        assert is_scaled_leaf(8)
+        assert is_scaled_leaf(2.0)
+        assert is_scaled_leaf(np.array(3))
+        assert is_scaled_leaf(np.array([3]))
+        assert is_scaled_leaf(jnp.array([3]))
+        assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16))
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index affe5ef..bb21b2d 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -2,11 +2,13 @@
 
 import chex
 import jax
-import jax.numpy as jnp
+
+# import jax.numpy as jnp
 import numpy as np
 import numpy.testing as npt
+from absl.testing import parameterized
 
-from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, register_scaled_op, scaled_array
 
 
 class AutoScaleInterpreterTests(chex.TestCase):
@@ -15,53 +17,64 @@ def test__register_scaled_op__error_if_already_registered(self):
         register_scaled_op(jax.lax.mul_p, lambda a, _: a)
 
     @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_identity_function(self):
+    def test__autoscale_interpreter__proper_signature(self):
         def func(x):
-            return x
-
-        # Autoscale + (optional) jitting.
-        asfunc = self.variant(autoscale(func))
-
-        scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32)
-        scaled_outputs = asfunc(scaled_inputs)
-        expected = jnp.array([1.0, 2.0])
-
-        assert isinstance(scaled_outputs, ScaledArray)
-        npt.assert_array_almost_equal(scaled_outputs, expected)
-        jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr
+            return x * 2
+
+        scaled_func = self.variant(autoscale(func))
+        scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32)
+        jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr
         # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
-        assert jaxpr.invars[0].aval.shape == scaled_inputs.shape
+        assert jaxpr.invars[0].aval.shape == scaled_input.shape
         assert jaxpr.invars[1].aval.shape == ()
-        assert jaxpr.outvars[0].aval.shape == expected.shape
+        assert jaxpr.outvars[0].aval.shape == scaled_input.shape
         assert jaxpr.outvars[1].aval.shape == ()
 
     @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_mul__no_attributes(self):
-        def func(x, y):
-            return x * y
-
-        # Autoscale + (optional) jitting.
-        asfunc = self.variant(autoscale(func))
-
-        x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32)
-        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
-        expected = jnp.array([-3.0, 3.0])
-
-        out = asfunc(x, y)
-        assert isinstance(out, ScaledArray)
-        npt.assert_array_almost_equal(out, expected)
-
-    @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_convert_element_type__attributes_passing(self):
-        def func(x):
-            return jax.lax.convert_element_type(x, np.float16)
-
-        # Autoscale + (optional) jitting.
-        asfunc = self.variant(autoscale(func))
-        x = scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)
-        out = asfunc(x)
-        assert isinstance(out, ScaledArray)
-        assert out.dtype == np.float16
-        npt.assert_array_almost_equal(out, x)
+    @parameterized.parameters(
+        # Identity function.
+        {"fn": lambda x: x, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
+        # Non-trivial output JAX PyTree.
+        {"fn": lambda x: {"x": (x,)}, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
+        # Multi-input operation.
+        {
+            "fn": lambda x, y: x * y,
+            "inputs": [scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), scaled_array([1.5, 1.5], 2, dtype=np.float32)],
+        },
+        # Proper forwarding of attributes.
+        {
+            "fn": lambda x: jax.lax.convert_element_type(x, np.float16),
+            "inputs": [scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)],
+        },
+        # Proper constant scalar handling.
+        {
+            "fn": lambda x: x * 2,
+            "inputs": [scaled_array([[-2.0, 0.5]], 0.5, dtype=np.float32)],
+        },
+        # TODO/FIXME: proper constant NumPy array handling.
+        # {
+        #     "fn": lambda x: x * np.array([2.0, 3.0], dtype=np.float32),
+        #     "inputs": [scaled_array([[-2.0], [0.5]], 0.5, dtype=np.float32)],
+        # },
+    )
+    def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, inputs):
+        # Autoscale function + (optional) jitting.
+        scaled_fn = self.variant(autoscale(fn))
+        scaled_output = scaled_fn(*inputs)
+        # Normal JAX path, without scaled arrays.
+        raw_inputs = [np.asarray(v) for v in inputs]
+        expected_output = self.variant(fn)(*raw_inputs)
+
+        # Do we properly reconstruct the output type (i.e. handle the PyTree properly)?
+        if not isinstance(expected_output, (np.ndarray, jax.Array)):
+            assert type(scaled_output) is type(expected_output)
+
+        # Check each output in the flattened tree.
+        scaled_outputs_flat, _ = jax.tree_util.tree_flatten(scaled_output, is_leaf=is_scaled_leaf)
+        expected_outputs_flat, _ = jax.tree_util.tree_flatten(expected_output)
+        for scaled_out, exp_out in zip(scaled_outputs_flat, expected_outputs_flat):
+            assert isinstance(scaled_out, ScaledArray)
+            assert scaled_out.scale.shape == ()
+            assert scaled_out.dtype == exp_out.dtype
+            npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)
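Note on the comparison pattern at the end of the test: flatten the scaled output with `is_leaf=is_scaled_leaf` so `ScaledArray` leaves stay whole, flatten the reference output the default way, then compare leaf by leaf. The same idea as a condensed standalone helper (a hypothetical sketch, not part of this PR):

```python
import jax
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, scaled_array


def check_autoscale(fn, *inputs):
    scaled_out = autoscale(fn)(*inputs)
    expected = fn(*[np.asarray(v) for v in inputs])
    scaled_flat, _ = jax.tree_util.tree_flatten(scaled_out, is_leaf=is_scaled_leaf)
    expected_flat, _ = jax.tree_util.tree_flatten(expected)
    for s, e in zip(scaled_flat, expected_flat):
        assert isinstance(s, ScaledArray)
        npt.assert_array_almost_equal(s, e, decimal=4)


check_autoscale(lambda x: {"x": (x,)}, scaled_array([1.0, 2.0], 3, dtype=np.float32))
```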
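Note on the disabled TODO/FIXME case: constant NumPy *arrays* fail because `numpy_constant_scaled_array` asserts `val.shape == ()`. One plausible generalization (an assumption, not part of this PR) would carry the array's maximum absolute value in the scale:

```python
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray


def numpy_array_scaled_array(val: np.ndarray) -> ScaledArray:
    # Hypothetical generalized rule (not in this PR): keep `data` in a unit
    # range and carry the magnitude in `scale`.
    scale = np.max(np.abs(val)).astype(val.dtype)
    if scale == 0:
        scale = np.array(1.0, dtype=val.dtype)
    return ScaledArray(data=val / scale, scale=scale)
```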