Skip to content

Commit

Permalink
Fix scaled argminmax parameter naming. (#48)
Browse files Browse the repository at this point in the history
Classic case of decrepancy between JAX LAX API, and primitives parameters.
  • Loading branch information
balancap authored Dec 4, 2023
1 parent 5ae6190 commit b946582
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
8 changes: 4 additions & 4 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> Sca


@core.register_scaled_lax_op
def scaled_argmax(val: ScaledArray, axis: int, index_dtype: DTypeLike) -> Array:
def scaled_argmax(val: ScaledArray, axes: Sequence[int], index_dtype: DTypeLike) -> Array:
# Note: returning a normal `int` Array.
return lax.argmax(val.data, axis=axis, index_dtype=index_dtype)
return lax.argmax_p.bind(val.data, axes=axes, index_dtype=index_dtype)


@core.register_scaled_lax_op
def scaled_argmin(val: ScaledArray, axis: int, index_dtype: DTypeLike) -> Array:
def scaled_argmin(val: ScaledArray, axes: Sequence[int], index_dtype: DTypeLike) -> Array:
# Note: returning a normal `int` Array.
return lax.argmin(val.data, axis=axis, index_dtype=index_dtype)
return lax.argmin_p.bind(val.data, axes=axes, index_dtype=index_dtype)


@core.register_scaled_lax_op
Expand Down
16 changes: 9 additions & 7 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from jax_scaled_arithmetics.lax import (
scaled_abs,
scaled_add,
scaled_argmax,
scaled_broadcast_in_dim,
scaled_concatenate,
scaled_convert_element_type,
Expand Down Expand Up @@ -105,12 +104,6 @@ def test__scaled_neg__proper_scaling(self):
assert z.scale == x.scale
npt.assert_array_almost_equal(z.data, -x.data)

def test__scaled_argmax__proper_scaling(self):
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_argmax(x, 0, np.int32)
assert isinstance(z, Array)
npt.assert_array_equal(z, np.argmax(x.data))

def test__scaled_abs__proper_scaling(self):
x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32)
z = scaled_abs(x)
Expand Down Expand Up @@ -198,6 +191,15 @@ def test__scaled_max__proper_scaling(self):
assert isinstance(z, ScaledArray)
npt.assert_array_almost_equal(z, np.maximum(x, y))

@parameterized.parameters({"prim": lax.argmax_p}, {"prim": lax.argmin_p})
def test__scaled_argminmax__proper_scaling(self, prim):
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
expected_out = prim.bind(x.to_array(), axes=(0,), index_dtype=np.int32)
scaled_translation, _ = find_registered_scaled_op(prim)
out = scaled_translation(x, axes=(0,), index_dtype=np.int32)
assert isinstance(out, Array)
npt.assert_array_equal(out, expected_out)


class ScaledTranslationReducePrimitivesTests(chex.TestCase):
def setUp(self):
Expand Down

0 comments on commit b946582

Please sign in to comment.