diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index 6fa6ea6..9f2f7fa 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -105,15 +105,15 @@ def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> Sca @core.register_scaled_lax_op -def scaled_argmax(val: ScaledArray, axis: int, index_dtype: DTypeLike) -> Array: +def scaled_argmax(val: ScaledArray, axes: Sequence[int], index_dtype: DTypeLike) -> Array: # Note: returning a normal `int` Array. - return lax.argmax(val.data, axis=axis, index_dtype=index_dtype) + return lax.argmax_p.bind(val.data, axes=axes, index_dtype=index_dtype) @core.register_scaled_lax_op -def scaled_argmin(val: ScaledArray, axis: int, index_dtype: DTypeLike) -> Array: +def scaled_argmin(val: ScaledArray, axes: Sequence[int], index_dtype: DTypeLike) -> Array: # Note: returning a normal `int` Array. - return lax.argmin(val.data, axis=axis, index_dtype=index_dtype) + return lax.argmin_p.bind(val.data, axes=axes, index_dtype=index_dtype) @core.register_scaled_lax_op diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py index 6e24552..e6f6616 100644 --- a/tests/lax/test_scaled_ops.py +++ b/tests/lax/test_scaled_ops.py @@ -9,7 +9,6 @@ from jax_scaled_arithmetics.lax import ( scaled_abs, scaled_add, - scaled_argmax, scaled_broadcast_in_dim, scaled_concatenate, scaled_convert_element_type, @@ -105,12 +104,6 @@ def test__scaled_neg__proper_scaling(self): assert z.scale == x.scale npt.assert_array_almost_equal(z.data, -x.data) - def test__scaled_argmax__proper_scaling(self): - x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) - z = scaled_argmax(x, 0, np.int32) - assert isinstance(z, Array) - npt.assert_array_equal(z, np.argmax(x.data)) - def test__scaled_abs__proper_scaling(self): x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32) z = scaled_abs(x) @@ -198,6 +191,15 @@ def test__scaled_max__proper_scaling(self): assert isinstance(z, ScaledArray) npt.assert_array_almost_equal(z, np.maximum(x, y)) + @parameterized.parameters({"prim": lax.argmax_p}, {"prim": lax.argmin_p}) + def test__scaled_argminmax__proper_scaling(self, prim): + x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) + expected_out = prim.bind(x.to_array(), axes=(0,), index_dtype=np.int32) + scaled_translation, _ = find_registered_scaled_op(prim) + out = scaled_translation(x, axes=(0,), index_dtype=np.int32) + assert isinstance(out, Array) + npt.assert_array_equal(out, expected_out) + class ScaledTranslationReducePrimitivesTests(chex.TestCase): def setUp(self):