Skip to content

Commit

Permalink
Improve as_scaled_array robustness on int/bool/scalars. (#45)
Browse files Browse the repository at this point in the history
Change the default behaviour of `as_scaled_array` such that:
* int/bool inputs (scalars and arrays) are returned unchanged;
* float scalars are converted to a `ScaledArray` with data equal to 1 and the value stored in the scale;
  • Loading branch information
balancap authored Dec 1, 2023
1 parent 44a0de2 commit f4e0331
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
26 changes: 23 additions & 3 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,39 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
return scaled_array_base(data, scale, dtype, npapi)


def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[Array, ScaledArray]:
    """ScaledArray (helper) base factory method, similar to `(j)np.array`.

    Integer and boolean inputs (Python scalars, or objects with such a dtype)
    are returned unchanged. Floating point scalars are represented with data
    `1` and the value stored in the scale.

    Args:
        val: Value to convert: a ScaledArray, a Python bool/int/float, or an
            object exposing `dtype` and `ndim` (NumPy/JAX arrays and scalars).
        scale: Optional scale. Only `None` or a unit scale is accepted for now.

    Returns:
        `val` itself when it is already a ScaledArray or has an int/bool type;
        otherwise a ScaledArray (or the result of `scaled_array_base` for
        inputs which are neither NumPy nor JAX arrays).

    Raises:
        AssertionError: if `scale` is neither `None` nor equal to 1.
    """
    # Already scaled: returned as-is (checked first, before `val.dtype` is touched).
    if isinstance(val, ScaledArray):
        return val

    # FIXME: support general case.
    assert scale is None or np.float32(scale) == np.float32(1)  # type:ignore
    # Trivial cases: bool, int, float.
    if isinstance(val, (bool, int)):
        return val
    if isinstance(val, float):
        # Python floats are stored as float32, with the value held by the scale.
        return ScaledArray(np.array(1, dtype=np.float32), np.float32(val))

    # Ignored dtypes by default: int and bool
    ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_)
    if ignored_dtype:
        return val
    # Floating point scalar: unit data, value moved into the scale.
    if val.ndim == 0:
        return ScaledArray(np.array(1, dtype=val.dtype), val)

    # The default unit scale is only built here, once `val` is known to expose a dtype.
    scale = np.array(1, dtype=val.dtype) if scale is None else scale
    if isinstance(val, (np.ndarray, Array)):
        return ScaledArray(val, scale)
    return scaled_array_base(val, scale)


def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
"""ScaledArray (helper) factory method, similar to `(j)np.array`.
NOTE: by default, int and bool values/arrays will be returned unchanged, as
in most cases, there is no value representing these as scaled arrays.
Compatible with JAX PyTree.
Args:
Expand Down
39 changes: 30 additions & 9 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,49 @@ def test__is_scaled_leaf__consistent_with_jax(self):
assert is_scaled_leaf(scaled_array(data=[1.0, 2.0], scale=3, dtype=np.float16))

@parameterized.parameters(
    {"data": np.array([1, 2.0])},
    {"data": jnp.array([1, 2.0])},
)
def test__as_scaled_array__unchanged_dtype(self, data):
    """Float arrays are wrapped in a ScaledArray with a unit scale and the same dtype."""
    output = as_scaled_array(data)
    assert isinstance(output, ScaledArray)
    # The underlying data keeps its array type (NumPy vs JAX) and dtype.
    assert isinstance(output.data, type(data))
    assert output.dtype == data.dtype
    assert output.dtype in {np.dtype(np.float32), np.dtype(np.float64)}
    npt.assert_array_almost_equal(output, data)
    npt.assert_array_equal(output.scale, np.array(1, dtype=data.dtype))
    # Invariant when calling a second time.
    assert as_scaled_array(output) is output

@parameterized.parameters(
    {"data": 2.1},
    {"data": np.float64(2.0)},
    {"data": jnp.float32(2.0)},
)
def test__as_scaled_array__float_scalar(self, data):
    """Float scalars are converted with unit data and the value held by the scale."""
    scaled = as_scaled_array(data)
    assert isinstance(scaled, ScaledArray)
    payload, factor = scaled.data, scaled.scale
    # Data and scale must share the same dtype.
    assert payload.dtype == factor.dtype
    npt.assert_array_almost_equal(payload, 1)
    npt.assert_array_almost_equal(factor, data)

@parameterized.parameters(
    {"data": False},
    {"data": 2},
    {"data": np.array([1, 2])},
    {"data": jnp.array([1, 2])},
)
def test__as_scaled_array__unscaled_bool_int_output(self, data):
    """Bool/int scalars and integer arrays come back as the very same object."""
    assert as_scaled_array(data) is data

def test__as_scaled_array__complex_pytree(self):
    """In a PyTree, only float leaves are converted; int and already-scaled leaves are kept."""
    # "x": int array, "y": float array, "z": already a ScaledArray.
    inputs = {"x": jnp.array([1, 2]), "y": jnp.array([1.0, 2]), "z": as_scaled_array(jnp.array([1.0, 2]))}
    output = as_scaled_array(inputs)
    assert isinstance(output, dict)
    assert len(output) == 3
    # Integer leaf: same object returned.
    assert output["x"] is inputs["x"]
    # Float leaf: converted, with the same represented values.
    assert isinstance(output["y"], ScaledArray)
    npt.assert_array_equal(output["y"], inputs["y"])
    # Scaled leaf: same object returned.
    assert output["z"] is inputs["z"]

@parameterized.parameters(
{"data": np.array(2)},
Expand Down

0 comments on commit f4e0331

Please sign in to comment.