diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index 5bbd252..1fd6f4d 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from typing import Any, Optional, Sequence, Tuple
 
+import jax
 import jax.numpy as jnp
 import numpy as np
 from jax import lax
@@ -145,3 +146,52 @@ def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
     data = lax.reduce_min_p.bind(val.data, axes=axes)
     # unchanged scaling.
     return ScaledArray(data, val.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_is_finite(val: ScaledArray) -> jax.Array:
+    assert isinstance(val, ScaledArray)
+    if np.issubdtype(val.scale.dtype, np.integer):
+        # Integer scale case => only check the data component.
+        return lax.is_finite(val.data)
+    # Both data & scale need to be finite!
+    return lax.and_p.bind(lax.is_finite(val.data), lax.is_finite(val.scale))
+
+
+def scaled_boolean_binary_op(lhs: ScaledArray, rhs: ScaledArray, prim: jax.core.Primitive) -> jax.Array:
+    """Generic implementation of any boolean binary operation."""
+    assert isinstance(lhs, ScaledArray)
+    assert isinstance(rhs, ScaledArray)
+    # FIXME: fix this absolute horror!
+    # TODO: use max scale + special case for scalars.
+    return prim.bind(lhs.to_array(dtype=np.float32), rhs.to_array(dtype=np.float32))
+
+
+@core.register_scaled_lax_op
+def scaled_eq(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
+    return scaled_boolean_binary_op(lhs, rhs, lax.eq_p)
+
+
+@core.register_scaled_lax_op
+def scaled_ne(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
+    return scaled_boolean_binary_op(lhs, rhs, lax.ne_p)
+
+
+@core.register_scaled_lax_op
+def scaled_gt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
+    return scaled_boolean_binary_op(lhs, rhs, lax.gt_p)
+
+
+@core.register_scaled_lax_op
+def scaled_ge(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
+    return scaled_boolean_binary_op(lhs, rhs, lax.ge_p)
+
+
+@core.register_scaled_lax_op
+def scaled_lt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
+    return scaled_boolean_binary_op(lhs, rhs, lax.lt_p)
+
+
+@core.register_scaled_lax_op
+def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
+    return scaled_boolean_binary_op(lhs, rhs, lax.le_p)
diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py
index 799f545..bb67fdc 100644
--- a/tests/lax/test_scaled_ops.py
+++ b/tests/lax/test_scaled_ops.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 import chex
+import jax
 import numpy as np
 import numpy.testing as npt
 from absl.testing import parameterized
@@ -12,6 +13,7 @@
     scaled_concatenate,
     scaled_convert_element_type,
     scaled_dot_general,
+    scaled_is_finite,
     scaled_mul,
     scaled_slice,
     scaled_sub,
@@ -120,3 +122,23 @@ def test__scaled_reduce__single_axis__proper_scaling(self, reduce_prim, expected
         assert out.dtype == val.dtype
         npt.assert_almost_equal(out.scale, expected_scale)
         npt.assert_array_almost_equal(out, reduce_prim.bind(np.asarray(val), axes=axes))
+
+
+class ScaledTranslationBooleanPrimitivesTests(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        # Use random state for reproducibility!
+        self.rs = np.random.RandomState(42)
+
+    @parameterized.parameters(
+        {"val": scaled_array([2, 3], 2.0, dtype=np.float32), "expected_out": [True, True]},
+        # Supporting `int` scale as well.
+        {"val": scaled_array([2, np.inf], 2, dtype=np.float32), "expected_out": [True, False]},
+        {"val": scaled_array([2, 3], np.nan, dtype=np.float32), "expected_out": [False, False]},
+        {"val": scaled_array([np.nan, 3], 3.0, dtype=np.float32), "expected_out": [False, True]},
+    )
+    def test__scaled_is_finite__proper_result(self, val, expected_out):
+        out = scaled_is_finite(val)
+        assert isinstance(out, jax.Array)
+        assert out.dtype == np.bool_
+        npt.assert_array_equal(out, expected_out)