From 318bd41e833d2ee2d773607e946baea61497d848 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 14 Dec 2023 14:23:05 +0000 Subject: [PATCH] Additional scale promotion unit test coverage in binary ops. (#55) Help catching & solving a bug around NumPy handling of `x**2`, dependent on whether `x` is a scalar or an array. --- jax_scaled_arithmetics/lax/scaled_ops.py | 4 +-- tests/core/test_interpreter.py | 2 +- tests/lax/test_scaled_ops.py | 43 ++++++++++-------------- 3 files changed, 21 insertions(+), 28 deletions(-) diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index 175c173..9b661c2 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -166,7 +166,7 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray: A, B = promote_scale_types(A, B) assert np.issubdtype(A.scale.dtype, np.floating) # TODO: what happens to `sqrt` for non-floating scale? - output_scale = lax.sqrt(A.scale**2 + B.scale**2) + output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale) # Output dtype => promotion of A and B dtypes. outdtype = jnp.promote_types(A.dtype, B.dtype) Arescale = (A.scale / output_scale).astype(outdtype) @@ -188,7 +188,7 @@ def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray: A, B = promote_scale_types(A, B) assert np.issubdtype(A.scale.dtype, np.floating) # TODO: what happens to `sqrt` for non-floating scale? - output_scale = lax.sqrt(A.scale**2 + B.scale**2) + output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale) # Output dtype => promotion of A and B dtypes. 
outdtype = jnp.promote_types(A.dtype, B.dtype) Arescale = (A.scale / output_scale).astype(outdtype) diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 06d87a1..c685050 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -171,7 +171,7 @@ def fn(x, y): npt.assert_array_almost_equal(scaled_primals, primals) npt.assert_array_almost_equal(scaled_tangents, tangents) - @chex.variants(with_jit=False, without_jit=True) + @chex.variants(with_jit=True, without_jit=True) def test__autoscale_decorator__custom_vjp__proper_graph_transformation_and_result(self): # JAX official `vjp` example. @jax.custom_vjp diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py index 7dfa70a..afa106a 100644 --- a/tests/lax/test_scaled_ops.py +++ b/tests/lax/test_scaled_ops.py @@ -134,6 +134,7 @@ def setUp(self): # Use random state for reproducibility! self.rs = np.random.RandomState(42) + @chex.variants(with_jit=True, without_jit=True) @parameterized.parameters( {"prim": lax.exp_p, "dtype": np.float16, "expected_scale": 1.0}, # FIXME! {"prim": lax.log_p, "dtype": np.float16, "expected_scale": 1.0}, # FIXME! @@ -145,7 +146,7 @@ def setUp(self): def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected_scale): scaled_op, _ = find_registered_scaled_op(prim) val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=dtype) - out = scaled_op(val) + out = self.variant(scaled_op)(val) expected_output = prim.bind(np.asarray(val)) assert isinstance(out, ScaledArray) assert out.dtype == val.dtype @@ -160,35 +161,27 @@ def setUp(self): # Use random state for reproducibility! 
self.rs = np.random.RandomState(42) - @parameterized.parameters( - {"prim": lax.add_p, "dtype": np.float32}, - {"prim": lax.sub_p, "dtype": np.float32}, - {"prim": lax.mul_p, "dtype": np.float32}, - {"prim": lax.div_p, "dtype": np.float32}, - {"prim": lax.min_p, "dtype": np.float32}, - {"prim": lax.max_p, "dtype": np.float32}, - # Make sure type promotion is right! - {"prim": lax.add_p, "dtype": np.float16}, - {"prim": lax.sub_p, "dtype": np.float16}, - {"prim": lax.mul_p, "dtype": np.float16}, - {"prim": lax.div_p, "dtype": np.float16}, - {"prim": lax.min_p, "dtype": np.float16}, - {"prim": lax.max_p, "dtype": np.float16}, + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.div_p, lax.min_p, lax.max_p], + dtype=[np.float16, np.float32], + sdtype=[np.float16, np.float32], ) - def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype): + def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtype): scaled_op, _ = find_registered_scaled_op(prim) - x = scaled_array([-1.0, 2.0], 3.0, dtype=dtype) - y = scaled_array([1.5, 4.5], 2.0, dtype=dtype) - # Always use float32 for scale factor. - assert x.scale.dtype == np.float32 - assert y.scale.dtype == np.float32 - - z = scaled_op(x, y) + # NOTE: direct construction to avoid weirdity between NumPy array and scalar! + x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(3.0)) + y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(2.0)) + # Ensure scale factor has the right dtype. + assert x.scale.dtype == sdtype + assert y.scale.dtype == sdtype + + z = self.variant(scaled_op)(x, y) expected_z = prim.bind(np.asarray(x), np.asarray(y)) assert z.dtype == x.dtype - assert z.scale.dtype == x.scale.dtype - npt.assert_array_almost_equal(z, expected_z, decimal=5) + assert z.scale.dtype == sdtype + npt.assert_array_almost_equal(z, expected_z, decimal=3) @parameterized.parameters( {"prim": lax.add_p},