diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 369b881..8857bd1 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -120,12 +120,29 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa return scaled_array_base(data, scale, dtype, npapi) -def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray: +def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[Array, ScaledArray]: """ScaledArray (helper) base factory method, similar to `(j)np.array`.""" - scale = np.array(1, dtype=val.dtype) if scale is None else scale if isinstance(val, ScaledArray): return val - elif isinstance(val, (np.ndarray, Array)): + + # FIXME: support general case. + assert scale is None or np.float32(scale) == np.float32(1) # type:ignore + # Trivial cases: bool, int, float. + if isinstance(val, (bool, int)): + return val + if isinstance(val, float): + return ScaledArray(np.array(1, dtype=np.float32), np.float32(val)) + + # Ignored dtypes by default: int and bool + ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_) + if ignored_dtype: + return val + # Floating point scalar + if val.ndim == 0: + return ScaledArray(np.array(1, dtype=val.dtype), val) + + scale = np.array(1, dtype=val.dtype) if scale is None else scale + if isinstance(val, (np.ndarray, Array)): return ScaledArray(val, scale) return scaled_array_base(val, scale) @@ -133,6 +150,9 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> ScaledA def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray: """ScaledArray (helper) factory method, similar to `(j)np.array`. + NOTE: by default, int and bool values/arrays will be returned unchanged, as + in most cases, there is no value representing these as scaled arrays. + Compatible with JAX PyTree. Args: diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index e8f0010..d54ae2c 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -86,28 +86,49 @@ def test__is_scaled_leaf__consistent_with_jax(self): assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16)) @parameterized.parameters( - {"data": np.array(2)}, - {"data": np.array([1, 2])}, - {"data": jnp.array([1, 2])}, + {"data": np.array([1, 2.0])}, + {"data": jnp.array([1, 2.0])}, ) def test__as_scaled_array__unchanged_dtype(self, data): output = as_scaled_array(data) assert isinstance(output, ScaledArray) assert isinstance(output.data, type(data)) - assert output.dtype == data.dtype + assert output.dtype in {np.dtype(np.float32), np.dtype(np.float64)} npt.assert_array_almost_equal(output, data) npt.assert_array_equal(output.scale, np.array(1, dtype=data.dtype)) # unvariant when calling a second time. assert as_scaled_array(output) is output + @parameterized.parameters( + {"data": 2.1}, + {"data": np.float64(2.0)}, + {"data": jnp.float32(2.0)}, + ) + def test__as_scaled_array__float_scalar(self, data): + output = as_scaled_array(data) + assert isinstance(output, ScaledArray) + assert output.data.dtype == output.scale.dtype + npt.assert_array_almost_equal(output.data, 1) + npt.assert_array_almost_equal(output.scale, data) + + @parameterized.parameters( + {"data": False}, + {"data": 2}, + {"data": np.array([1, 2])}, + {"data": jnp.array([1, 2])}, + ) + def test__as_scaled_array__unscaled_bool_int_output(self, data): + output = as_scaled_array(data) + assert output is data + def test__as_scaled_array__complex_pytree(self): - input = {"x": jnp.array([1, 2]), "y": as_scaled_array(jnp.array([1, 2]))} + input = {"x": jnp.array([1, 2]), "y": jnp.array([1.0, 2]), "z": as_scaled_array(jnp.array([1.0, 2]))} output = as_scaled_array(input) assert isinstance(output, dict) - assert len(output) == 2 - assert isinstance(output["x"], ScaledArray) - npt.assert_array_equal(output["x"], input["x"]) - assert output["y"] is input["y"] + assert len(output) == 3 + assert output["x"] is input["x"] + npt.assert_array_equal(output["y"], input["y"]) + assert output["z"] is input["z"] @parameterized.parameters( {"data": np.array(2)},