Skip to content

Commit

Permalink
Support properly scale parameter in as_scaled_array function. (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap authored Dec 19, 2023
1 parent ea122af commit 591645c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
17 changes: 10 additions & 7 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,28 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> Union[A
if isinstance(val, ScaledArray):
return val

# FIXME: support general case.
assert scale is None or np.float32(scale) == np.float32(1) # type:ignore
# Simple case => when can ignore the scaling factor (i.e. 1 implicitely).
is_static_one_scale: bool = scale is None or is_static_one_scalar(scale) # type:ignore
# Trivial cases: bool, int, float.
if isinstance(val, (bool, int)):
if is_static_one_scale and isinstance(val, (bool, int)):
return val
if isinstance(val, float):
if is_static_one_scale and isinstance(val, float):
return ScaledArray(np.array(1, dtype=np.float32), np.float32(val))

# Ignored dtypes by default: int and bool
ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_)
if ignored_dtype:
return val
# Floating point scalar
if val.ndim == 0:
if val.ndim == 0 and is_static_one_scale:
return ScaledArray(np.array(1, dtype=val.dtype), val)

scale = np.array(1, dtype=val.dtype) if scale is None else scale
if isinstance(val, (np.ndarray, Array)):
return ScaledArray(val, scale)
if is_static_one_scale:
return ScaledArray(val, scale)
else:
return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore
return scaled_array_base(val, scale)


Expand All @@ -165,7 +168,7 @@ def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
Returns:
Scaled array instance.
"""
return jax.tree_map(lambda x: as_scaled_array_base(x, None), val, is_leaf=is_scaled_leaf)
return jax.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf)


def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
Expand Down
11 changes: 11 additions & 0 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ def test__as_scaled_array__float_scalar(self, data):
npt.assert_array_almost_equal(output.data, 1)
npt.assert_array_almost_equal(output.scale, data)

@parameterized.parameters(
{"data": jnp.float32(3.0)},
)
def test__as_scaled_array__optional_scalar(self, data):
scale = np.float16(2)
output = as_scaled_array(data, scale=scale)
assert isinstance(output, ScaledArray)
assert output.scale.dtype == np.float16
npt.assert_array_equal(output.scale, np.array(2, dtype=np.float16))
npt.assert_array_almost_equal(output, data)

@parameterized.parameters(
{"data": False},
{"data": 2},
Expand Down

0 comments on commit 591645c

Please sign in to comment.