diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index c42c7ea..0200779 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -1,5 +1,13 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from .datatype import DTypeLike, ScaledArray, Shape, asarray, is_scaled_leaf, scaled_array # noqa: F401 +from .datatype import ( # noqa: F401 + DTypeLike, + ScaledArray, + Shape, + as_scaled_array, + asarray, + is_scaled_leaf, + scaled_array, +) from .interpreters import ( # noqa: F401 ScaledPrimitiveType, autoscale, diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 038ef44..b3768d6 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Optional, Union import jax import jax.numpy as jnp @@ -87,6 +87,23 @@ def aval(self) -> ShapedArray: return ShapedArray(self.data.shape, self.data.dtype) +def is_scaled_leaf(val: Any) -> bool: + """Is input a JAX PyTree (scaled) leaf, including ScaledArray. + + This function is useful for JAX PyTree handling where the user wants + to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays). + """ + # `np.isscalar` covers Python and NumPy scalars; arrays and ScaledArray are matched by type. 
+ return np.isscalar(val) or isinstance(val, (jax.Array, np.ndarray, ScaledArray)) + + +def scaled_array_base(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray: + """ScaledArray (helper) base factory method, similar to `(j)np.array`.""" + data = npapi.asarray(data, dtype=dtype) + scale = npapi.asarray(scale) + return ScaledArray(data, scale) + + def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray: """ScaledArray (helper) factory method, similar to `(j)np.array`. @@ -98,17 +115,35 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa Returns: Scaled array instance. """ - data = npapi.asarray(data, dtype=dtype) - scale = npapi.asarray(scale) - return ScaledArray(data, scale) + return scaled_array_base(data, scale, dtype, npapi) -def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray: - """Convert back to a common JAX/Numpy array. +def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray: + """Convert a single leaf value to a ScaledArray (base function); a ScaledArray input is returned as-is. NOTE(review): `val.dtype` is read before any type check - presumably fails on Python scalars when `scale` is None; confirm.""" + scale = np.array(1, dtype=val.dtype) if scale is None else scale + if isinstance(val, ScaledArray): + return val + elif isinstance(val, (np.ndarray, jax.Array)): + return ScaledArray(val, scale) + return scaled_array_base(val, scale) + + +def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray: + """ScaledArray (helper) factory method, similar to `(j)np.array`. + + Compatible with JAX PyTree. Args: - dtype: Optional dtype of the final array. + val: Main data/values or existing ScaledArray. + scale: Optional scale to use when (potentially) converting. NOTE(review): not forwarded at present - the leaf converter is always called with `None`. + Returns: + Scaled array instance. 
+ """ + return jax.tree_map(lambda x: as_scaled_array_base(x, None), val, is_leaf=is_scaled_leaf) + + +def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray: + """Convert back to a common JAX/Numpy array, base function.""" if isinstance(val, ScaledArray): return val.to_array(dtype=dtype) elif isinstance(val, (jax.Array, np.ndarray)): @@ -119,11 +154,12 @@ def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray: return np.asarray(val, dtype=dtype) -def is_scaled_leaf(val: Any) -> bool: - """Is input a JAX PyTree (scaled) leaf, including ScaledArray. +def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray: + """Convert back to a common JAX/Numpy array. - This function is useful for JAX PyTree handling where the user wants - to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays). + Compatible with JAX PyTree. + + Args: + dtype: Optional dtype of the final array, applied to every leaf. """ - # TODO: check scalars as well? - return isinstance(val, (jax.Array, np.ndarray, ScaledArray, int, float)) + return jax.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf) diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index 70eafb8..d11a34b 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -1,12 +1,13 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
import chex +import jax import jax.numpy as jnp import numpy as np import numpy.testing as npt from absl.testing import parameterized from jax.core import ShapedArray -from jax_scaled_arithmetics.core import ScaledArray, asarray, is_scaled_leaf, scaled_array +from jax_scaled_arithmetics.core import ScaledArray, as_scaled_array, asarray, is_scaled_leaf, scaled_array class ScaledArrayDataclassTests(chex.TestCase): @@ -78,11 +79,37 @@ def test__scaled_array__numpy_array_interface(self, npapi): def test__is_scaled_leaf__consistent_with_jax(self): assert is_scaled_leaf(8) assert is_scaled_leaf(2.0) + assert is_scaled_leaf(np.int32(2)) + assert is_scaled_leaf(np.float32(2)) assert is_scaled_leaf(np.array(3)) assert is_scaled_leaf(np.array([3])) assert is_scaled_leaf(jnp.array([3])) assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16)) + @parameterized.parameters( + {"data": np.array(2)}, + {"data": np.array([1, 2])}, + {"data": jnp.array([1, 2])}, + ) + def test__as_scaled_array__unchanged_dtype(self, data): + output = as_scaled_array(data) + assert isinstance(output, ScaledArray) + assert isinstance(output.data, type(data)) + assert output.dtype == data.dtype + npt.assert_array_almost_equal(output, data) + npt.assert_array_equal(output.scale, np.array(1, dtype=data.dtype)) + # Invariant when called a second time. 
+ assert as_scaled_array(output) is output + + def test__as_scaled_array__complex_pytree(self): + input = {"x": jnp.array([1, 2]), "y": as_scaled_array(jnp.array([1, 2]))} + output = as_scaled_array(input) + assert isinstance(output, dict) + assert len(output) == 2 + assert isinstance(output["x"], ScaledArray) + npt.assert_array_equal(output["x"], input["x"]) + assert output["y"] is input["y"] + @parameterized.parameters( {"data": np.array(2)}, {"data": np.array([1, 2])}, @@ -102,3 +129,12 @@ def test__asarray__changed_dtype(self, data): output = asarray(data, dtype=np.float16) assert output.dtype == np.float16 npt.assert_array_almost_equal(output, data) + + def test__asarray__complex_pytree(self): + input = {"x": jnp.array([1.0, 2]), "y": scaled_array(jnp.array([3, 4.0]), jnp.array(0.5))} + output = asarray(input) + assert isinstance(output, dict) + assert len(output) == 2 + assert all([isinstance(v, jax.Array) for v in output.values()]) + npt.assert_array_almost_equal(output["x"], input["x"]) + npt.assert_array_almost_equal(output["y"], input["y"])