Additional scale promotion unit test coverage in binary ops. (#55)
Helps catch & solve a bug around NumPy handling of `x**2`,
which behaves differently depending on whether `x` is a scalar or an array.
balancap authored Dec 14, 2023
1 parent 51f6452 commit 318bd41
Showing 3 changed files with 21 additions and 28 deletions.
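
Context for the fix below: the output-scale computation replaces `x**2` with `x * x`. A minimal, hypothetical probe of the dtype behaviour in question (plain NumPy; the exact promotion rules vary across NumPy versions, which is precisely why the scalar/array distinction matters):

```python
import numpy as np

# Probe the result dtype of `x**2` vs. `x * x` for a float16 scalar and a
# float16 array. `x * x` always keeps the operand dtype; `x**2` (Python int
# exponent) goes through NumPy's promotion machinery, whose outcome can
# differ between scalars and arrays depending on the NumPy version.
for x in (np.float16(3.0), np.array([3.0], dtype=np.float16)):
    print(type(x).__name__, (x ** 2).dtype, (x * x).dtype)
```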
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -166,7 +166,7 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
     # TODO: what happens to `sqrt` for non-floating scale?
-    output_scale = lax.sqrt(A.scale**2 + B.scale**2)
+    output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale)
     # Output dtype => promotion of A and B dtypes.
     outdtype = jnp.promote_types(A.dtype, B.dtype)
     Arescale = (A.scale / output_scale).astype(outdtype)
@@ -188,7 +188,7 @@ def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
     # TODO: what happens to `sqrt` for non-floating scale?
-    output_scale = lax.sqrt(A.scale**2 + B.scale**2)
+    output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale)
     # Output dtype => promotion of A and B dtypes.
     outdtype = jnp.promote_types(A.dtype, B.dtype)
     Arescale = (A.scale / output_scale).astype(outdtype)
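
For reference, the rescaling in `scaled_add`/`scaled_sub` follows from representing each operand as `data * scale`: combining scales as sqrt(a² + b²) and rescaling both operands into the output scale leaves the represented values unchanged. A minimal sketch in plain NumPy (values are illustrative):

```python
import numpy as np

# Scaled representation: value = data * scale.
a_data, a_scale = np.array([-1.0, 2.0], np.float32), np.float32(3.0)
b_data, b_scale = np.array([1.5, 4.5], np.float32), np.float32(2.0)

# Combined output scale, written as s * s rather than s**2 (mirroring the fix).
out_scale = np.sqrt(a_scale * a_scale + b_scale * b_scale)
# Rescale both operands into the output scale, then add the data terms.
out_data = (a_scale / out_scale) * a_data + (b_scale / out_scale) * b_data

# The represented value is preserved: out_data * out_scale == A + B.
np.testing.assert_allclose(
    out_data * out_scale, a_data * a_scale + b_data * b_scale, rtol=1e-6
)
```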
2 changes: 1 addition & 1 deletion tests/core/test_interpreter.py
@@ -171,7 +171,7 @@ def fn(x, y):
         npt.assert_array_almost_equal(scaled_primals, primals)
         npt.assert_array_almost_equal(scaled_tangents, tangents)

-    @chex.variants(with_jit=False, without_jit=True)
+    @chex.variants(with_jit=True, without_jit=True)
     def test__autoscale_decorator__custom_vjp__proper_graph_transformation_and_result(self):
         # JAX official `vjp` example.
         @jax.custom_vjp
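
For context on the `with_jit=True` change: `chex.variants` runs each decorated test once per enabled variant, and `self.variant(fn)` wraps `fn` so that the `with_jit` variant executes it under `jax.jit`. A minimal sketch of the pattern (class and test names are hypothetical):

```python
import chex
import jax.numpy as jnp

class SquareTest(chex.TestCase):  # hypothetical example class
    @chex.variants(with_jit=True, without_jit=True)
    def test_square(self):
        # The test body runs twice: once with `square` jitted, once eagerly.
        square = self.variant(lambda x: x * x)
        chex.assert_trees_all_close(square(jnp.float32(3.0)), jnp.float32(9.0))
```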
43 changes: 18 additions & 25 deletions tests/lax/test_scaled_ops.py
@@ -134,6 +134,7 @@ def setUp(self):
         # Use random state for reproducibility!
         self.rs = np.random.RandomState(42)

+    @chex.variants(with_jit=True, without_jit=True)
     @parameterized.parameters(
         {"prim": lax.exp_p, "dtype": np.float16, "expected_scale": 1.0},  # FIXME!
         {"prim": lax.log_p, "dtype": np.float16, "expected_scale": 1.0},  # FIXME!
@@ -145,7 +146,7 @@
     def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected_scale):
         scaled_op, _ = find_registered_scaled_op(prim)
         val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=dtype)
-        out = scaled_op(val)
+        out = self.variant(scaled_op)(val)
         expected_output = prim.bind(np.asarray(val))
         assert isinstance(out, ScaledArray)
         assert out.dtype == val.dtype
@@ -160,35 +161,27 @@ def setUp(self):
         # Use random state for reproducibility!
         self.rs = np.random.RandomState(42)

-    @parameterized.parameters(
-        {"prim": lax.add_p, "dtype": np.float32},
-        {"prim": lax.sub_p, "dtype": np.float32},
-        {"prim": lax.mul_p, "dtype": np.float32},
-        {"prim": lax.div_p, "dtype": np.float32},
-        {"prim": lax.min_p, "dtype": np.float32},
-        {"prim": lax.max_p, "dtype": np.float32},
-        # Make sure type promotion is right!
-        {"prim": lax.add_p, "dtype": np.float16},
-        {"prim": lax.sub_p, "dtype": np.float16},
-        {"prim": lax.mul_p, "dtype": np.float16},
-        {"prim": lax.div_p, "dtype": np.float16},
-        {"prim": lax.min_p, "dtype": np.float16},
-        {"prim": lax.max_p, "dtype": np.float16},
+    @chex.variants(with_jit=True, without_jit=True)
+    @parameterized.product(
+        prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.div_p, lax.min_p, lax.max_p],
+        dtype=[np.float16, np.float32],
+        sdtype=[np.float16, np.float32],
     )
-    def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype):
+    def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtype):
         scaled_op, _ = find_registered_scaled_op(prim)
-        x = scaled_array([-1.0, 2.0], 3.0, dtype=dtype)
-        y = scaled_array([1.5, 4.5], 2.0, dtype=dtype)
-        # Always use float32 for scale factor.
-        assert x.scale.dtype == np.float32
-        assert y.scale.dtype == np.float32
-
-        z = scaled_op(x, y)
+        # NOTE: direct construction to avoid weirdity between NumPy array and scalar!
+        x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(3.0))
+        y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(2.0))
+        # Ensure scale factor has the right dtype.
+        assert x.scale.dtype == sdtype
+        assert y.scale.dtype == sdtype
+
+        z = self.variant(scaled_op)(x, y)
         expected_z = prim.bind(np.asarray(x), np.asarray(y))

         assert z.dtype == x.dtype
-        assert z.scale.dtype == x.scale.dtype
-        npt.assert_array_almost_equal(z, expected_z, decimal=5)
+        assert z.scale.dtype == sdtype
+        npt.assert_array_almost_equal(z, expected_z, decimal=3)

     @parameterized.parameters(
         {"prim": lax.add_p},
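
For reference on the `parameterized.product` rewrite above: unlike `parameterized.parameters`, which takes explicit cases, `product` expands the Cartesian product of the given argument lists, so the twelve hand-written cases become 6 × 2 × 2 = 24 generated tests covering every `(prim, dtype, sdtype)` combination. A minimal, hypothetical sketch of the expansion:

```python
import numpy as np
from absl.testing import absltest, parameterized

class ProductExpansionTest(parameterized.TestCase):  # hypothetical example class
    @parameterized.product(
        dtype=[np.float16, np.float32],
        sdtype=[np.float16, np.float32],
    )
    def test_expansion(self, dtype, sdtype):
        # Generated once per (dtype, sdtype) pair: 2 x 2 = 4 test cases.
        assert np.issubdtype(dtype, np.floating)
        assert np.issubdtype(sdtype, np.floating)

if __name__ == "__main__":
    absltest.main()
```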
