diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 9a26867..8505da8 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -199,7 +199,7 @@ def promote_to_scaled_array(val):
             )
         else:
             # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
-            invals = map(promote_to_scaled_array, invals)
+            invals = list(map(promote_to_scaled_array, invals))
             outvals = scaled_prim_fn(*invals, **eqn.params)
 
         if not eqn.primitive.multiple_results:
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index badd160..1cfc90e 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -6,9 +6,10 @@
 import jax.numpy as jnp
 import numpy as np
 from jax import lax
+from jax._src.ad_util import add_any_p
 
 from jax_scaled_arithmetics import core
-from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape
+from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape, register_scaled_op
 
 from .base_scaling_primitives import scaled_set_scaling
 
@@ -41,6 +42,11 @@ def scaled_stop_gradient(val: ScaledArray) -> ScaledArray:
     return ScaledArray(lax.stop_gradient(val.data), lax.stop_gradient(val.scale))
 
 
+@core.register_scaled_lax_op
+def scaled_reshape(A: ScaledArray, new_sizes: Sequence[int], dimensions: Optional[Sequence[int]]) -> ScaledArray:
+    return ScaledArray(lax.reshape(A.data, new_sizes=new_sizes, dimensions=dimensions), A.scale)
+
+
 @core.register_scaled_lax_op
 def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray:
     return ScaledArray(lax.broadcast_in_dim(A.data, shape=shape, broadcast_dimensions=broadcast_dimensions), A.scale)
@@ -81,8 +87,19 @@ def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:
 
 
 @core.register_scaled_lax_op
-def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
-    return ScaledArray(A.data * B.data, A.scale * B.scale)
+def scaled_neg(val: ScaledArray) -> ScaledArray:
+    return ScaledArray(-val.data, val.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    return ScaledArray(lhs.data * rhs.data, lhs.scale * rhs.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    # TODO: investigate different rule?
+    return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)
 
 
@@ -97,6 +114,10 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
     return ScaledArray(output_data, output_scale)
 
 
+# TODO: understand difference between `add` and `add_anys`
+register_scaled_op(add_any_p, scaled_add)
+
+
 @core.register_scaled_lax_op
 def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
     check_scalar_scales(A, B)
@@ -274,3 +295,13 @@ def scaled_cos(val: ScaledArray) -> ScaledArray:
 @core.register_scaled_lax_op
 def scaled_sin(val: ScaledArray) -> ScaledArray:
     return scaled_op_default_translation(lax.sin_p, [val])
+
+
+@core.register_scaled_lax_op
+def scaled_min(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    return scaled_op_default_translation(lax.min_p, [lhs, rhs])
+
+
+@core.register_scaled_lax_op
+def scaled_max(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    return scaled_op_default_translation(lax.max_p, [lhs, rhs])
diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py
index bfb68e2..527928e 100644
--- a/tests/lax/test_scaled_ops.py
+++ b/tests/lax/test_scaled_ops.py
@@ -12,11 +12,16 @@
     scaled_broadcast_in_dim,
     scaled_concatenate,
     scaled_convert_element_type,
+    scaled_div,
     scaled_dot_general,
     scaled_exp,
     scaled_is_finite,
     scaled_log,
+    scaled_max,
+    scaled_min,
     scaled_mul,
+    scaled_neg,
+    scaled_reshape,
     scaled_select_n,
     scaled_slice,
     scaled_sub,
@@ -37,6 +42,13 @@ def test__scaled_broadcast_in_dim__proper_scaling(self):
         npt.assert_array_equal(z.scale, x.scale)
         npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1)))
 
+    def test__scaled_reshape__proper_scaling(self):
+        x = scaled_array(self.rs.rand(8), 2, dtype=np.float32)
+        z = scaled_reshape(x, new_sizes=(4, 2), dimensions=None)
+        assert isinstance(z, ScaledArray)
+        npt.assert_array_equal(z.scale, x.scale)
+        npt.assert_array_almost_equal(z.data, x.data.reshape((4, 2)))
+
     def test__scaled_concatenate__proper_scaling(self):
         x = scaled_array(self.rs.rand(2, 3), 0.5, dtype=np.float32)
         y = scaled_array(self.rs.rand(5, 3), 2, dtype=np.float32)
@@ -59,6 +71,13 @@ def test__scaled_transpose__proper_scaling(self):
         assert z.scale == x.scale
         npt.assert_array_almost_equal(z.data, x.data.T)
 
+    def test__scaled_neg__proper_scaling(self):
+        x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32)
+        z = scaled_neg(x)
+        assert isinstance(z, ScaledArray)
+        assert z.scale == x.scale
+        npt.assert_array_almost_equal(z.data, -x.data)
+
     def test__scaled_slice__proper_scaling(self):
         x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
         z = scaled_slice(x, (1,), (4,), (2,))
@@ -74,6 +93,14 @@ def test__scaled_mul__proper_scaling(self):
         assert z.scale == 6
         npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y))
 
+    def test__scaled_div__proper_scaling(self):
+        x = scaled_array([-2.0, 2.0], 3.0, dtype=np.float32)
+        y = scaled_array([1.5, 1.5], 2.0, dtype=np.float32)
+        z = scaled_div(x, y)
+        assert isinstance(z, ScaledArray)
+        assert z.scale == 1.5
+        npt.assert_array_almost_equal(z, np.asarray(x) / np.asarray(y))
+
     def test__scaled_add__proper_scaling(self):
         x = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
         y = scaled_array([1.5, 4.5], 2.0, dtype=np.float32)
@@ -117,6 +144,20 @@ def test__scaled_log__proper_scaling(self):
         npt.assert_almost_equal(out.scale, 1)  # FIXME!
         npt.assert_array_almost_equal(out, np.log(val))
 
+    def test__scaled_min__proper_scaling(self):
+        x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
+        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
+        z = scaled_min(x, y)
+        assert isinstance(z, ScaledArray)
+        npt.assert_array_almost_equal(z, np.minimum(x, y))
+
+    def test__scaled_max__proper_scaling(self):
+        x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
+        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
+        z = scaled_max(x, y)
+        assert isinstance(z, ScaledArray)
+        npt.assert_array_almost_equal(z, np.maximum(x, y))
+
 
 class ScaledTranslationReducePrimitivesTests(chex.TestCase):
     def setUp(self):
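Usage note (not part of the patch): a minimal sketch of the scale-propagation rules added in this diff, namely scales multiply under `scaled_mul`, divide under `scaled_div`, and pass through unchanged for `scaled_neg` and `scaled_reshape`. Module paths and the direct `ScaledArray(data, scale)` construction are inferred from the diff itself; treat the snippet as illustrative rather than authoritative.

```python
# Illustrative sketch only: import paths follow the file layout shown in the diff.
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.lax.scaled_ops import scaled_div, scaled_mul, scaled_neg, scaled_reshape

# ScaledArray(data, scale) represents the tensor `data * scale`.
x = ScaledArray(np.array([-2.0, 2.0], dtype=np.float32), np.float32(3.0))
y = ScaledArray(np.array([1.5, 1.5], dtype=np.float32), np.float32(2.0))

z = scaled_mul(x, y)  # scales multiply: z.scale == 6.0
w = scaled_div(x, y)  # scales divide: w.scale == 1.5
n = scaled_neg(x)     # scale unchanged: n.scale == 3.0
r = scaled_reshape(x, new_sizes=(2, 1), dimensions=None)  # data reshaped, scale unchanged
```

When the interpreter in core/interpreters.py encounters the corresponding lax primitives on ScaledArray inputs, it dispatches to these registered rules automatically, so the same propagation applies to traced code without calling the scaled ops by hand.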