diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index b5b66bc..b0948c1 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -9,7 +9,7 @@ from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p from jax._src.util import safe_map -from .datatype import NDArray, ScaledArray, is_scaled_leaf +from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf class ScaledPrimitiveType(IntEnum): @@ -55,15 +55,8 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray: Note: needs to work with any input type, including JAX tracer ones. """ - # int / float special cases - if isinstance(val, float): - return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val)) - elif isinstance(val, int): - return ScaledArray(data=np.array(1, dtype=np.int32), scale=np.int32(val)) - # Just a Numpy constant for data => can be optimized out in XLA compiler. - assert val.shape == () - onedata = np.array(1, dtype=val.dtype) - return ScaledArray(data=onedata, scale=val) + # Use `as_scaled_array` promotion rules. + return as_scaled_array_base(val) def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray: diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index fca3919..b24092b 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -172,14 +172,24 @@ def fn(x, y): npt.assert_array_almost_equal(scaled_tangents, tangents) @parameterized.parameters( - {"input": np.array(3)}, - {"input": jnp.array(3)}, - {"input": 3}, {"input": 3.0}, + {"input": np.float32(3.0)}, + {"input": np.array(3.0)}, + {"input": jnp.array(3.0)}, ) - def test__promote_scalar_to_scaled_array__proper_output(self, input): + def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input): scaled_val = promote_scalar_to_scaled_array(input) assert isinstance(scaled_val, ScaledArray) assert scaled_val.data.dtype == scaled_val.scale.dtype npt.assert_array_equal(scaled_val.data, 1) npt.assert_array_equal(scaled_val.scale, input) + + @parameterized.parameters( + {"input": np.array(3)}, + {"input": jnp.array(3)}, + {"input": 3}, + {"input": np.int32(2)}, + ) + def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input): + out = promote_scalar_to_scaled_array(input) + assert out is input