Skip to content

Commit

Permalink
Fix autoscale scalar promotion.
Browse files Browse the repository at this point in the history
Using `as_scalar_array` function for promotion, to keep rules
consistent (i.e. no promotion for int or bool).
  • Loading branch information
balancap committed Dec 4, 2023
1 parent b946582 commit 75c5520
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array, base function."""
if isinstance(val, ScaledArray):
return val.to_array(dtype=dtype)
elif isinstance(val, (Array, np.ndarray)):
elif isinstance(val, (*ArrayTypes, np.ndarray)):
if dtype is None:
return val
return val.astype(dtype=dtype)
Expand Down
13 changes: 3 additions & 10 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, is_scaled_leaf
from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf


class ScaledPrimitiveType(IntEnum):
Expand Down Expand Up @@ -55,15 +55,8 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
Note: needs to work with any input type, including JAX tracer ones.
"""
# int / float special cases
if isinstance(val, float):
return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val))
elif isinstance(val, int):
return ScaledArray(data=np.array(1, dtype=np.int32), scale=np.int32(val))
# Just a Numpy constant for data => can be optimized out in XLA compiler.
assert val.shape == ()
onedata = np.array(1, dtype=val.dtype)
return ScaledArray(data=onedata, scale=val)
# Use `as_scaled_array` promotion rules.
return as_scaled_array_base(val)


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
Expand Down
12 changes: 9 additions & 3 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,16 @@ def test__as_scaled_array__complex_pytree(self):
npt.assert_array_equal(output["y"], input["y"])
assert output["z"] is input["z"]

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
{"data": np.array(2)},
{"data": np.array([1, 2])},
{"data": np.int32(3)},
{"data": np.array(2, dtype=np.int32)},
{"data": np.array([1, 2], dtype=np.int32)},
{"data": np.array([1, 2.0], dtype=np.float32)},
{"data": jnp.array([1, 2])},
)
def test__asarray__unchanged_dtype(self, data):
output = asarray(data)
output = self.variant(asarray)(data)
assert output.dtype == data.dtype
npt.assert_array_almost_equal(output, data)

Expand All @@ -167,6 +170,9 @@ def test__asarray__complex_pytree(self):
npt.assert_array_almost_equal(output["x"], input["x"])
npt.assert_array_almost_equal(output["y"], input["y"])

# def test__asarray__tracing(self, data):
# pass

@parameterized.parameters(
{"val": 0, "result": True},
{"val": 0.0, "result": True},
Expand Down
18 changes: 14 additions & 4 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,24 @@ def fn(x, y):
npt.assert_array_almost_equal(scaled_tangents, tangents)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": 3},
{"input": 3.0},
{"input": np.float32(3.0)},
{"input": np.array(3.0)},
{"input": jnp.array(3.0)},
)
def test__promote_scalar_to_scaled_array__proper_output(self, input):
def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input):
scaled_val = promote_scalar_to_scaled_array(input)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.data.dtype == scaled_val.scale.dtype
npt.assert_array_equal(scaled_val.data, 1)
npt.assert_array_equal(scaled_val.scale, input)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": 3},
{"input": np.int32(2)},
)
def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input):
out = promote_scalar_to_scaled_array(input)
assert out is input

0 comments on commit 75c5520

Please sign in to comment.