Skip to content

Commit

Permalink
as_scaled_array and asarray factory methods compatible with JAX P…
Browse files Browse the repository at this point in the history
…yTrees.

Similar to `jax.device_put` / `jax.device_get`, this allows easy handling of complex data structures
such as model state or batches.
  • Loading branch information
balancap committed Nov 27, 2023
1 parent 1ec9dd6 commit d13fa71
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 15 deletions.
10 changes: 9 additions & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, asarray, is_scaled_leaf, scaled_array # noqa: F401
from .datatype import ( # noqa: F401
DTypeLike,
ScaledArray,
Shape,
as_scaled_array,
asarray,
is_scaled_leaf,
scaled_array,
)
from .interpreters import ( # noqa: F401
ScaledPrimitiveType,
autoscale,
Expand Down
62 changes: 49 additions & 13 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -87,6 +87,23 @@ def aval(self) -> ShapedArray:
return ShapedArray(self.data.shape, self.data.dtype)


def is_scaled_leaf(val: Any) -> bool:
    """Tell whether `val` should be treated as a (scaled) JAX PyTree leaf.

    Meant to be passed as `is_leaf` predicate in JAX PyTree utilities, such
    that ScaledArray instances are kept whole instead of being flattened
    into a pair of arrays.

    Args:
        val: Any Python object.
    Returns:
        True for scalars (as defined by `np.isscalar`), JAX arrays,
        NumPy arrays and ScaledArray instances.
    """
    # Scalars first: `np.isscalar` accepts Python and NumPy scalar types.
    if np.isscalar(val):
        return True
    leaf_types = (jax.Array, np.ndarray, ScaledArray)
    return isinstance(val, leaf_types)


def scaled_array_base(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray:
    """Build a ScaledArray from raw data and scale, similar to `(j)np.array`.

    Args:
        data: Main data/values, converted with `npapi.asarray` using `dtype`.
        scale: Scale value, converted with `npapi.asarray` (no dtype forced).
        dtype: Optional dtype of the data array.
        npapi: Array module providing `asarray` (JAX NumPy by default).
    Returns:
        Scaled array instance.
    """
    return ScaledArray(npapi.asarray(data, dtype=dtype), npapi.asarray(scale))


def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray:
"""ScaledArray (helper) factory method, similar to `(j)np.array`.
Expand All @@ -98,17 +115,35 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
Returns:
Scaled array instance.
"""
data = npapi.asarray(data, dtype=dtype)
scale = npapi.asarray(scale)
return ScaledArray(data, scale)
return scaled_array_base(data, scale, dtype, npapi)


def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array.
def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
    """Convert a single (non-PyTree) value into a ScaledArray.

    Args:
        val: Existing ScaledArray, NumPy/JAX array, or any other value which
            can be converted by `jnp.asarray` (e.g. a Python scalar or list).
        scale: Optional scale to attach. When omitted, a scale of 1 with the
            same dtype as the data is used. Ignored if `val` is already a
            ScaledArray.
    Returns:
        `val` itself when it is already a ScaledArray, otherwise a new
        ScaledArray wrapping the data.
    """
    # Already scaled: hand back the very same object (no copy, no rescaling).
    if isinstance(val, ScaledArray):
        return val
    if isinstance(val, (np.ndarray, jax.Array)):
        # Keep the input array object untouched (no NumPy <-> JAX conversion).
        if scale is None:
            scale = np.array(1, dtype=val.dtype)
        return ScaledArray(val, scale)
    # Python scalars, lists, ... have no `dtype` attribute: convert to an array
    # first, and only then derive the default scale from the resulting dtype.
    data = jnp.asarray(val)
    if scale is None:
        scale = np.array(1, dtype=data.dtype)
    return scaled_array_base(data, scale)


def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
    """ScaledArray (helper) factory method, similar to `(j)np.array`.

    Compatible with JAX PyTree: every leaf, as defined by `is_scaled_leaf`,
    is converted independently, and the tree structure is preserved.

    Args:
        val: Main data/values or existing ScaledArray (or a PyTree of those).
        scale: Optional scale to use when (potentially) converting. The same
            scale is forwarded to the conversion of every leaf.
    Returns:
        Scaled array instance (or a PyTree of them, mirroring the input).
    """
    # Forward `scale` to the per-leaf conversion (it used to be dropped).
    # `jax.tree_util.tree_map` is the non-deprecated spelling of `jax.tree_map`.
    return jax.tree_util.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf)


def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array, base function."""
if isinstance(val, ScaledArray):
return val.to_array(dtype=dtype)
elif isinstance(val, (jax.Array, np.ndarray)):
Expand All @@ -119,11 +154,12 @@ def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
return np.asarray(val, dtype=dtype)


def is_scaled_leaf(val: Any) -> bool:
"""Is input a JAX PyTree (scaled) leaf, including ScaledArray.
def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array.
This function is useful for JAX PyTree handling where the user wants
to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
Compatible with JAX PyTree.
Args:
dtype: Optional dtype of the final array.
"""
# TODO: check scalars as well?
return isinstance(val, (jax.Array, np.ndarray, ScaledArray, int, float))
return jax.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf)
38 changes: 37 additions & 1 deletion tests/core/test_datatype.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax.core import ShapedArray

from jax_scaled_arithmetics.core import ScaledArray, asarray, is_scaled_leaf, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, as_scaled_array, asarray, is_scaled_leaf, scaled_array


class ScaledArrayDataclassTests(chex.TestCase):
Expand Down Expand Up @@ -78,11 +79,37 @@ def test__scaled_array__numpy_array_interface(self, npapi):
def test__is_scaled_leaf__consistent_with_jax(self):
    """Scalars, NumPy/JAX arrays and ScaledArray are all reported as leaves."""
    expected_leaves = (
        8,
        2.0,
        np.int32(2),
        np.float32(2),
        np.array(3),
        np.array([3]),
        jnp.array([3]),
        scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16),
    )
    for leaf in expected_leaves:
        assert is_scaled_leaf(leaf)

@parameterized.parameters(
    {"data": np.array(2)},
    {"data": np.array([1, 2])},
    {"data": jnp.array([1, 2])},
)
def test__as_scaled_array__unchanged_dtype(self, data):
    """Converting an array keeps its type and dtype, with a unit scale."""
    scaled = as_scaled_array(data)
    assert isinstance(scaled, ScaledArray)
    assert isinstance(scaled.data, type(data))
    assert scaled.dtype == data.dtype
    npt.assert_array_almost_equal(scaled, data)
    unit_scale = np.array(1, dtype=data.dtype)
    npt.assert_array_equal(scaled.scale, unit_scale)
    # Invariant when called a second time: the same object is returned.
    assert as_scaled_array(scaled) is scaled

def test__as_scaled_array__complex_pytree(self):
    """Dict PyTree: plain arrays get wrapped, ScaledArray leaves pass through."""
    tree = {"x": jnp.array([1, 2]), "y": as_scaled_array(jnp.array([1, 2]))}
    converted = as_scaled_array(tree)
    assert isinstance(converted, dict)
    assert len(converted) == 2
    # Plain JAX array leaf is converted, with unchanged values.
    assert isinstance(converted["x"], ScaledArray)
    npt.assert_array_equal(converted["x"], tree["x"])
    # Existing ScaledArray leaf is returned as the very same object.
    assert converted["y"] is tree["y"]

@parameterized.parameters(
{"data": np.array(2)},
{"data": np.array([1, 2])},
Expand All @@ -102,3 +129,12 @@ def test__asarray__changed_dtype(self, data):
output = asarray(data, dtype=np.float16)
assert output.dtype == np.float16
npt.assert_array_almost_equal(output, data)

def test__asarray__complex_pytree(self):
    """Dict PyTree: every leaf comes back as a plain JAX array."""
    tree = {"x": jnp.array([1.0, 2]), "y": scaled_array(jnp.array([3, 4.0]), jnp.array(0.5))}
    converted = asarray(tree)
    assert isinstance(converted, dict)
    assert len(converted) == 2
    assert all(isinstance(leaf, jax.Array) for leaf in converted.values())
    for key in ("x", "y"):
        npt.assert_array_almost_equal(converted[key], tree[key])

0 comments on commit d13fa71

Please sign in to comment.