Skip to content

Commit

Permalink
Adding is_static_zero helper function, checking if a ScaledArray is…
Browse files Browse the repository at this point in the history
… statically zero. (#47)
  • Loading branch information
balancap authored Dec 4, 2023
1 parent f4e0331 commit 5ae6190
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
as_scaled_array,
asarray,
is_scaled_leaf,
is_static_zero,
scaled_array,
)
from .interpreters import ( # noqa: F401
Expand Down
23 changes: 23 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,26 @@ def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
dtype: Optional dtype of the final array.
"""
return jax.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf)


def is_numpy_scalar_or_array(val):
return isinstance(val, np.ndarray) or np.isscalar(val)


def is_static_zero(val: Union[Array, ScaledArray]) -> Array:
"""Is a scaled array a static zero value (i.e. zero during JAX tracing as well)?
Returns a boolean Numpy array of the shape of the input.
"""
if is_numpy_scalar_or_array(val):
return np.equal(val, 0)
if isinstance(val, ScaledArray):
data_mask = (
np.equal(val.data, 0) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_)
)
scale_mask = (
np.equal(val.scale, 0) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_)
)
return np.logical_or(data_mask, scale_mask)
# By default: can't decide.
return np.zeros(val.shape, dtype=np.bool_)
26 changes: 25 additions & 1 deletion tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
from absl.testing import parameterized
from jax.core import ShapedArray

from jax_scaled_arithmetics.core import Array, ScaledArray, as_scaled_array, asarray, is_scaled_leaf, scaled_array
from jax_scaled_arithmetics.core import (
Array,
ScaledArray,
as_scaled_array,
asarray,
is_scaled_leaf,
is_static_zero,
scaled_array,
)


class ScaledArrayDataclassTests(chex.TestCase):
Expand Down Expand Up @@ -158,3 +166,19 @@ def test__asarray__complex_pytree(self):
assert all([isinstance(v, Array) for v in output.values()])
npt.assert_array_almost_equal(output["x"], input["x"])
npt.assert_array_almost_equal(output["y"], input["y"])

@parameterized.parameters(
{"val": 0, "result": True},
{"val": 0.0, "result": True},
{"val": np.int32(0), "result": True},
{"val": np.float16(0), "result": True},
{"val": np.array([1, 2]), "result": False},
{"val": np.array([0, 0.0]), "result": True},
{"val": jnp.array([0, 0.0]), "result": False},
{"val": ScaledArray(np.array([0, 0.0]), jnp.array(2.0)), "result": True},
{"val": ScaledArray(jnp.array([3, 4.0]), np.array(0.0)), "result": True},
{"val": ScaledArray(jnp.array([3, 4.0]), jnp.array(0.0)), "result": False},
)
def test__is_static_zero__proper_all_result(self, val, result):
all_zero = np.all(is_static_zero(val))
assert all_zero == result

0 comments on commit 5ae6190

Please sign in to comment.