From 8fdedddf662a2c92f465a1f1ff5e1fa83df3a637 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 9 Nov 2023 16:59:34 +0000 Subject: [PATCH] Fix ScaledArray aval and create factory method. * `ScaledArray.aval` returning JAX `ShapedArray`, like JAX API; * Factory method `scaled_array`, similar to `jnp.array`; The latter makes testing code simpler & clearer. --- jax_scaled_arithmetics/__init__.py | 2 +- jax_scaled_arithmetics/core/__init__.py | 2 +- jax_scaled_arithmetics/core/datatype.py | 25 +++++++- tests/core/test_datatype.py | 59 ++++++++++++------- .../test_interpreter.py | 32 ++++------ 5 files changed, 75 insertions(+), 45 deletions(-) rename tests/{interpreters => core}/test_interpreter.py (58%) diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scaled_arithmetics/__init__.py index 5a42b7e..e0205f2 100644 --- a/jax_scaled_arithmetics/__init__.py +++ b/jax_scaled_arithmetics/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from . import lax from ._version import __version__ -from .core import ScaledArray, autoscale # noqa: F401 +from .core import ScaledArray, autoscale, scaled_array # noqa: F401 diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index 6d6ff96..22692c4 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from .datatype import ScaledArray # noqa: F401 +from .datatype import ScaledArray, scaled_array # noqa: F401 from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401 diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index cfcc518..eb656df 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -3,10 +3,12 @@ from typing import Any, Union import jax +import jax.numpy as jnp import numpy as np from chex import Shape +from jax.core import ShapedArray from jax.tree_util import register_pytree_node_class -from numpy.typing import DTypeLike, NDArray +from numpy.typing import ArrayLike, DTypeLike, NDArray GenericArray = Union[jax.Array, np.ndarray] @@ -80,5 +82,22 @@ def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]: return np.asarray(self.to_array(dtype)) @property - def aval(self): - return self.data * self.scale + def aval(self) -> ShapedArray: + """Abstract value of the scaled array, i.e. shape and dtype.""" + return ShapedArray(self.data.shape, self.data.dtype) + + +def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray: + """ScaledArray (helper) factory method, similar to `(j)np.array`. + + Args: + data: Main data/values. + scale: Scale tensor. + dtype: Optional dtype to use for the data. + npapi: Numpy API to use. + Returns: + Scaled array instance. + """ + data = npapi.asarray(data, dtype=dtype) + scale = npapi.asarray(scale) + return ScaledArray(data, scale) diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index fcdc1c8..f0e0ad0 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -4,54 +4,73 @@ import numpy as np import numpy.testing as npt from absl.testing import parameterized +from jax.core import ShapedArray -from jax_scaled_arithmetics.core import ScaledArray +from jax_scaled_arithmetics import ScaledArray, scaled_array class ScaledArrayDataclassTests(chex.TestCase): @parameterized.parameters( - {"npb": np}, - {"npb": jnp}, + {"npapi": np}, + {"npapi": jnp}, ) - def test__scaled_array__init__multi_numpy_backend(self, npb): - sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float32), scale=npb.array(1)) - assert isinstance(sarr.data, npb.ndarray) - assert isinstance(sarr.scale, npb.ndarray) + def test__scaled_array__init__multi_numpy_backend(self, npapi): + sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(1)) + assert isinstance(sarr.data, npapi.ndarray) + assert isinstance(sarr.scale, npapi.ndarray) assert sarr.scale.shape == () - def test__scaled_array__basic_properties(self): - sarr = ScaledArray(data=jnp.array([1.0, 2.0]), scale=jnp.array(1)) + @parameterized.parameters( + {"npapi": np}, + {"npapi": jnp}, + ) + def test__scaled_array__factory_method__multi_numpy_backend(self, npapi): + sarr = scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16, npapi=npapi) + assert isinstance(sarr, ScaledArray) + assert isinstance(sarr.data, npapi.ndarray) + assert isinstance(sarr.scale, npapi.ndarray) + assert sarr.data.dtype == ShapedArray((2,), np.float16) + assert sarr.scale.shape == () + npt.assert_array_almost_equal(sarr, [3, 6]) + + @parameterized.parameters( + {"npapi": np}, + {"npapi": jnp}, + ) + def test__scaled_array__basic_properties(self, npapi): + sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(1)) assert sarr.dtype == np.float32 assert sarr.shape == (2,) + assert sarr.aval == ShapedArray((2,), np.float32) @parameterized.parameters( - {"npb": np}, - {"npb": jnp}, + {"npapi": np}, + {"npapi": jnp}, ) - def test__scaled_array__to_array__multi_numpy_backend(self, npb): - sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float16), scale=npb.array(3)) + def test__scaled_array__to_array__multi_numpy_backend(self, npapi): + sarr = scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16, npapi=npapi) # No dtype specified. out = sarr.to_array() - assert isinstance(out, npb.ndarray) + assert isinstance(out, npapi.ndarray) assert out.dtype == sarr.dtype npt.assert_array_equal(out, sarr.data * sarr.scale) # Custom float dtype. out = sarr.to_array(dtype=np.float32) - assert isinstance(out, npb.ndarray) + assert isinstance(out, npapi.ndarray) assert out.dtype == np.float32 npt.assert_array_equal(out, sarr.data * sarr.scale) # Custom int dtype. out = sarr.to_array(dtype=np.int8) - assert isinstance(out, npb.ndarray) + assert isinstance(out, npapi.ndarray) assert out.dtype == np.int8 npt.assert_array_equal(out, sarr.data * sarr.scale) @parameterized.parameters( - {"npb": np}, - {"npb": jnp}, + {"npapi": np}, + {"npapi": jnp}, ) - def test__scaled_array__numpy_array_interface(self, npb): - sarr = ScaledArray(data=npb.array([1.0, 2.0], dtype=np.float32), scale=npb.array(3)) + def test__scaled_array__numpy_array_interface(self, npapi): + sarr = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float32), scale=npapi.array(3)) out = np.asarray(sarr) assert isinstance(out, np.ndarray) npt.assert_array_equal(out, sarr.data * sarr.scale) diff --git a/tests/interpreters/test_interpreter.py b/tests/core/test_interpreter.py similarity index 58% rename from tests/interpreters/test_interpreter.py rename to tests/core/test_interpreter.py index 5bbb4ad..7249e3f 100644 --- a/tests/interpreters/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -3,8 +3,10 @@ import chex import jax import jax.numpy as jnp +import numpy as np +import numpy.testing as npt -from jax_scaled_arithmetics.core import ScaledArray, autoscale +from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array class AutoScaleInterpreterTests(chex.TestCase): @@ -14,20 +16,16 @@ def func(x): asfunc = autoscale(func) - scale = jnp.array(1.0) - inputs = jnp.array([1.0, 2.0]) - expected = jnp.array([1.0, 2.0]) - - scaled_inputs = ScaledArray(inputs, scale) + scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32) scaled_outputs = asfunc(scaled_inputs) + expected = jnp.array([1.0, 2.0]) - assert jnp.allclose(scaled_outputs.aval, expected) - + assert isinstance(scaled_outputs, ScaledArray) + npt.assert_array_almost_equal(scaled_outputs, expected) jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray - - assert jaxpr.invars[0].aval.shape == inputs.shape + assert jaxpr.invars[0].aval.shape == scaled_inputs.shape assert jaxpr.invars[1].aval.shape == () assert jaxpr.outvars[0].aval.shape == expected.shape @@ -39,16 +37,10 @@ def func(x, y): asfunc = autoscale(func) - x_in = jnp.array([-2.0, 2.0]) - x_scale = jnp.array(0.5) - x = ScaledArray(x_in, x_scale) - - y_in = jnp.array([1.5, 1.5]) - y_scale = jnp.array(2.0) - y = ScaledArray(y_in, y_scale) - + x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32) + y = scaled_array([1.5, 1.5], 2, dtype=np.float32) expected = jnp.array([-3.0, 3.0]) out = asfunc(x, y) - - assert jnp.allclose(out.aval, expected) + assert isinstance(out, ScaledArray) + npt.assert_array_almost_equal(out, expected)