From 1f8102d403eef2ae8f3b1589a547485691294c63 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Mon, 4 Dec 2023 15:49:43 +0000
Subject: [PATCH] Fix `autoscale` scalar promotion.

Using the `as_scaled_array` function for promotion, to keep rules
consistent (i.e. no promotion for int or bool).
---
 jax_scaled_arithmetics/core/datatype.py     |  2 +-
 jax_scaled_arithmetics/core/interpreters.py | 13 +++----------
 tests/core/test_datatype.py                 |  9 ++++++---
 tests/core/test_interpreter.py              | 18 ++++++++++++++----
 4 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py
index 63170d8..47ff376 100644
--- a/jax_scaled_arithmetics/core/datatype.py
+++ b/jax_scaled_arithmetics/core/datatype.py
@@ -168,7 +168,7 @@ def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
     """Convert back to a common JAX/Numpy array, base function."""
     if isinstance(val, ScaledArray):
         return val.to_array(dtype=dtype)
-    elif isinstance(val, (Array, np.ndarray)):
+    elif isinstance(val, (*ArrayTypes, np.ndarray)):
         if dtype is None:
             return val
         return val.astype(dtype=dtype)
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index b5b66bc..b0948c1 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -9,7 +9,7 @@
 from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p
 from jax._src.util import safe_map
 
-from .datatype import NDArray, ScaledArray, is_scaled_leaf
+from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
 
 
 class ScaledPrimitiveType(IntEnum):
@@ -55,15 +55,8 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
 
     Note: needs to work with any input type, including JAX tracer ones.
     """
-    # int / float special cases
-    if isinstance(val, float):
-        return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val))
-    elif isinstance(val, int):
-        return ScaledArray(data=np.array(1, dtype=np.int32), scale=np.int32(val))
-    # Just a Numpy constant for data => can be optimized out in XLA compiler.
-    assert val.shape == ()
-    onedata = np.array(1, dtype=val.dtype)
-    return ScaledArray(data=onedata, scale=val)
+    # Use `as_scaled_array` promotion rules.
+    return as_scaled_array_base(val)
 
 
 def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py
index 70fcc8d..8b0cf71 100644
--- a/tests/core/test_datatype.py
+++ b/tests/core/test_datatype.py
@@ -138,13 +138,16 @@ def test__as_scaled_array__complex_pytree(self):
         npt.assert_array_equal(output["y"], input["y"])
         assert output["z"] is input["z"]
 
+    @chex.variants(with_jit=True, without_jit=True)
     @parameterized.parameters(
-        {"data": np.array(2)},
-        {"data": np.array([1, 2])},
+        {"data": np.int32(3)},
+        {"data": np.array(2, dtype=np.int32)},
+        {"data": np.array([1, 2], dtype=np.int32)},
+        {"data": np.array([1, 2.0], dtype=np.float32)},
         {"data": jnp.array([1, 2])},
     )
     def test__asarray__unchanged_dtype(self, data):
-        output = asarray(data)
+        output = self.variant(asarray)(data)
         assert output.dtype == data.dtype
         npt.assert_array_almost_equal(output, data)
 
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index fca3919..b24092b 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -172,14 +172,24 @@ def fn(x, y):
         npt.assert_array_almost_equal(scaled_tangents, tangents)
 
     @parameterized.parameters(
-        {"input": np.array(3)},
-        {"input": jnp.array(3)},
-        {"input": 3},
         {"input": 3.0},
+        {"input": np.float32(3.0)},
+        {"input": np.array(3.0)},
+        {"input": jnp.array(3.0)},
     )
-    def test__promote_scalar_to_scaled_array__proper_output(self, input):
+    def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input):
         scaled_val = promote_scalar_to_scaled_array(input)
         assert isinstance(scaled_val, ScaledArray)
         assert scaled_val.data.dtype == scaled_val.scale.dtype
         npt.assert_array_equal(scaled_val.data, 1)
         npt.assert_array_equal(scaled_val.scale, input)
+
+    @parameterized.parameters(
+        {"input": np.array(3)},
+        {"input": jnp.array(3)},
+        {"input": 3},
+        {"input": np.int32(2)},
+    )
+    def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input):
+        out = promote_scalar_to_scaled_array(input)
+        assert out is input
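
Note (not part of the patch): a minimal usage sketch of the behaviour the new
tests pin down. The import paths below are assumed from the file layout shown
in this diff and may differ from the package's public API.

    import numpy as np

    # Assumed import paths, inferred from the modules touched by this patch.
    from jax_scaled_arithmetics.core.datatype import ScaledArray
    from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array

    # Float scalars are promoted: data becomes 1, the scale carries the value.
    out = promote_scalar_to_scaled_array(np.float32(3.0))
    assert isinstance(out, ScaledArray)
    assert out.data.dtype == out.scale.dtype
    np.testing.assert_array_equal(out.data, 1)
    np.testing.assert_array_equal(out.scale, 3.0)

    # Int (and bool) inputs are returned unchanged, per `as_scaled_array` rules.
    x = np.int32(2)
    assert promote_scalar_to_scaled_array(x) is x

Leaving integer inputs untouched matches the `as_scaled_array` promotion rules
referenced in the commit message, so scalar promotion inside `autoscale` and
explicit user-side promotion now behave the same way.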