From 5ae6190ccba23a97214ddce24fce51ee964413b3 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 4 Dec 2023 12:14:38 +0000 Subject: [PATCH] Adding `is_static_zero` helper function, checking if a ScaledArray is statically zero. (#47) --- jax_scaled_arithmetics/core/__init__.py | 1 + jax_scaled_arithmetics/core/datatype.py | 23 ++++++++++++++++++++++ tests/core/test_datatype.py | 26 ++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index 1f9715a..ae7a56e 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -6,6 +6,7 @@ as_scaled_array, asarray, is_scaled_leaf, + is_static_zero, scaled_array, ) from .interpreters import ( # noqa: F401 diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 8857bd1..63170d8 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -185,3 +185,26 @@ def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray: dtype: Optional dtype of the final array. """ return jax.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf) + + +def is_numpy_scalar_or_array(val): + return isinstance(val, np.ndarray) or np.isscalar(val) + + +def is_static_zero(val: Union[Array, ScaledArray]) -> Array: + """Is a scaled array a static zero value (i.e. zero during JAX tracing as well)? + + Returns a boolean Numpy array of the shape of the input. + """ + if is_numpy_scalar_or_array(val): + return np.equal(val, 0) + if isinstance(val, ScaledArray): + data_mask = ( + np.equal(val.data, 0) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_) + ) + scale_mask = ( + np.equal(val.scale, 0) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_) + ) + return np.logical_or(data_mask, scale_mask) + # By default: can't decide. + return np.zeros(val.shape, dtype=np.bool_) diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index d54ae2c..70fcc8d 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -6,7 +6,15 @@ from absl.testing import parameterized from jax.core import ShapedArray -from jax_scaled_arithmetics.core import Array, ScaledArray, as_scaled_array, asarray, is_scaled_leaf, scaled_array +from jax_scaled_arithmetics.core import ( + Array, + ScaledArray, + as_scaled_array, + asarray, + is_scaled_leaf, + is_static_zero, + scaled_array, +) class ScaledArrayDataclassTests(chex.TestCase): @@ -158,3 +166,19 @@ def test__asarray__complex_pytree(self): assert all([isinstance(v, Array) for v in output.values()]) npt.assert_array_almost_equal(output["x"], input["x"]) npt.assert_array_almost_equal(output["y"], input["y"]) + + @parameterized.parameters( + {"val": 0, "result": True}, + {"val": 0.0, "result": True}, + {"val": np.int32(0), "result": True}, + {"val": np.float16(0), "result": True}, + {"val": np.array([1, 2]), "result": False}, + {"val": np.array([0, 0.0]), "result": True}, + {"val": jnp.array([0, 0.0]), "result": False}, + {"val": ScaledArray(np.array([0, 0.0]), jnp.array(2.0)), "result": True}, + {"val": ScaledArray(jnp.array([3, 4.0]), np.array(0.0)), "result": True}, + {"val": ScaledArray(jnp.array([3, 4.0]), jnp.array(0.0)), "result": False}, + ) + def test__is_static_zero__proper_all_result(self, val, result): + all_zero = np.all(is_static_zero(val)) + assert all_zero == result