diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index b0548a9..a2d00c4 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -60,6 +60,25 @@ def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type:
     return ScaledArray(lax.convert_element_type(A.data, new_dtype=new_dtype), A.scale)
 
 
+@core.register_scaled_lax_op
+def scaled_reduce_precision(A: ScaledArray, exponent_bits: int, mantissa_bits: int) -> ScaledArray:
+    """Scaled counterpart of `lax.reduce_precision`.
+
+    The precision reduction is applied to the data term only; the scale is
+    passed through unchanged.
+
+    Args:
+        A: Scaled input array.
+        exponent_bits: Forwarded as-is to `lax.reduce_precision`.
+        mantissa_bits: Forwarded as-is to `lax.reduce_precision`.
+
+    Returns:
+        A `ScaledArray` holding the reduced-precision data and the scale of `A`.
+    """
+    reduced_data = lax.reduce_precision(A.data, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits)
+    return ScaledArray(reduced_data, A.scale)
+
+
 @core.register_scaled_lax_op
 def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> ScaledArray:
     # TODO: inputs checking (dtype and cie).
diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py
index 42cb6e9..2dfe9b5 100644
--- a/tests/lax/test_scaled_ops.py
+++ b/tests/lax/test_scaled_ops.py
@@ -20,6 +20,7 @@
     scaled_min,
     scaled_mul,
     scaled_neg,
+    scaled_reduce_precision,
     scaled_reshape,
     scaled_select_n,
     scaled_slice,
@@ -70,6 +71,14 @@ def test__scaled_transpose__proper_scaling(self):
         assert z.scale == x.scale
         npt.assert_array_almost_equal(z.data, x.data.T)
 
+    def test__scaled_reduce_precision__proper_result(self):
+        x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float16)
+        # Reduce to a pseudo FP8 format (4 exponent bits, 3 mantissa bits); the scale must be untouched.
+        z = scaled_reduce_precision(x, exponent_bits=4, mantissa_bits=3)
+        assert isinstance(z, ScaledArray)
+        assert z.scale == x.scale
+        npt.assert_array_almost_equal(z.data, lax.reduce_precision(x.data, exponent_bits=4, mantissa_bits=3))
+
     def test__scaled_neg__proper_scaling(self):
         x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32)
         z = scaled_neg(x)