diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py index b58ffd9..ce6ca5b 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_common.py +++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py @@ -15,6 +15,7 @@ Shape, as_scaled_array, get_scale_dtype, + is_static_anyscale, is_static_zero, safe_div, ) @@ -223,10 +224,10 @@ def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> Array: def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray: """General min/max scaled translation: propagating the largest input scale.""" check_scalar_scales(lhs, rhs) - # Specific rule if lhs/rhs is zero => propagate the other term scale. - if np.all(is_static_zero(lhs)): + # Specific rule if lhs/rhs is zero or inf => propagate the other term scale. + if np.all(is_static_anyscale(lhs)): return ScaledArray(prim.bind(lhs.data, rhs.data), rhs.scale) - if np.all(is_static_zero(rhs)): + if np.all(is_static_anyscale(rhs)): return ScaledArray(prim.bind(lhs.data, rhs.data), lhs.scale) # Power-of-2 stable! diff --git a/tests/lax/test_scaled_ops_l2.py b/tests/lax/test_scaled_ops_l2.py index 85c2600..3ea6256 100644 --- a/tests/lax/test_scaled_ops_l2.py +++ b/tests/lax/test_scaled_ops_l2.py @@ -158,9 +158,8 @@ def test__scaled_addsub__not_overflowing_scale(self, prim): assert np.isfinite(z.scale) npt.assert_array_almost_equal(z, prim.bind(np.asarray(x, np.float32), np.asarray(y, np.float32)), decimal=6) - @parameterized.parameters( - {"prim": lax.max_p}, - {"prim": lax.min_p}, + @parameterized.product( + prim=[lax.min_p, lax.max_p], ) def test__scaled_minmax__static_zero_scale_propagation(self, prim): scaled_op, _ = find_registered_scaled_op(prim) @@ -172,6 +171,19 @@ def test__scaled_minmax__static_zero_scale_propagation(self, prim): # Keep the lhs scale. npt.assert_almost_equal(z.scale, 4.0) + @parameterized.product( + prim=[lax.min_p, lax.max_p], + ) + def test__scaled_minmax__static_inf_scale_propagation(self, prim): + scaled_op, _ = find_registered_scaled_op(prim) + x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32, npapi=np) + y = scaled_array([-np.inf, np.inf], np.inf, dtype=np.float32, npapi=np) + z = scaled_op(x, y) + assert isinstance(z, ScaledArray) + assert z.dtype == x.dtype + # Keep the lhs scale. + npt.assert_almost_equal(z.scale, 4.0) + def test__scaled_mul__proper_scaling(self): x = scaled_array([-2.0, 2.0], 3, dtype=np.float32) y = scaled_array([1.5, 1.5], 2, dtype=np.float32) diff --git a/tests/lax/test_scipy_integration.py b/tests/lax/test_scipy_integration.py index 926e6ec..58085d5 100644 --- a/tests/lax/test_scipy_integration.py +++ b/tests/lax/test_scipy_integration.py @@ -25,16 +25,12 @@ def fn(a): @chex.variants(with_jit=False, without_jit=True) @parameterized.parameters( {"dtype": np.float32}, - # {"dtype": np.float16}, + {"dtype": np.float16}, ) def test__scipy_logsumexp__accurate_scaled_op(self, dtype): - import jax from jax.scipy.special import logsumexp input_scaled = scaled_array(self.rs.rand(10), 4.0, dtype=dtype) - - print(jax.make_jaxpr(logsumexp)(input_scaled.data)) - # JAX `logsumexp` Jaxpr is a non-trivial graph! out_scaled = self.variant(autoscale(logsumexp))(input_scaled) out_expected = logsumexp(np.asarray(input_scaled)) @@ -42,5 +38,3 @@ def test__scipy_logsumexp__accurate_scaled_op(self, dtype): # Proper accuracy + keep the same scale. npt.assert_array_equal(out_scaled.scale, input_scaled.scale) npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5) - - assert False