Skip to content

Commit

Permalink
Fix autoscale scalar promotion.
Browse files Browse the repository at this point in the history
Using `as_scalar_array` function for promotion, to keep rules
consistent (i.e. no promotion for int or bool).
  • Loading branch information
balancap committed Dec 4, 2023
1 parent b946582 commit 85fb76c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
13 changes: 3 additions & 10 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, is_scaled_leaf
from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf


class ScaledPrimitiveType(IntEnum):
Expand Down Expand Up @@ -55,15 +55,8 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
Note: needs to work with any input type, including JAX tracer ones.
"""
# int / float special cases
if isinstance(val, float):
return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val))
elif isinstance(val, int):
return ScaledArray(data=np.array(1, dtype=np.int32), scale=np.int32(val))
# Just a Numpy constant for data => can be optimized out in XLA compiler.
assert val.shape == ()
onedata = np.array(1, dtype=val.dtype)
return ScaledArray(data=onedata, scale=val)
# Use `as_scaled_array` promotion rules.
return as_scaled_array_base(val)


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
Expand Down
18 changes: 14 additions & 4 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,24 @@ def fn(x, y):
npt.assert_array_almost_equal(scaled_tangents, tangents)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": 3},
{"input": 3.0},
{"input": np.float32(3.0)},
{"input": np.array(3.0)},
{"input": jnp.array(3.0)},
)
def test__promote_scalar_to_scaled_array__proper_output(self, input):
def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input):
scaled_val = promote_scalar_to_scaled_array(input)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.data.dtype == scaled_val.scale.dtype
npt.assert_array_equal(scaled_val.data, 1)
npt.assert_array_equal(scaled_val.scale, input)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": 3},
{"input": np.int32(2)},
)
def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input):
out = promote_scalar_to_scaled_array(input)
assert out is input

0 comments on commit 85fb76c

Please sign in to comment.