Skip to content

Commit

Permalink
Additional simple scaled operations coverage. (#35)
Browse files Browse the repository at this point in the history
Operations added: `reshape`, `div`, `min`, `max`.
  • Loading branch information
balancap authored Nov 25, 2023
1 parent ba6115a commit 1ec9dd6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def promote_to_scaled_array(val):
)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
invals = map(promote_to_scaled_array, invals)
invals = list(map(promote_to_scaled_array, invals))
outvals = scaled_prim_fn(*invals, **eqn.params)

if not eqn.primitive.multiple_results:
Expand Down
37 changes: 34 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax._src.ad_util import add_any_p

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape, register_scaled_op

from .base_scaling_primitives import scaled_set_scaling

Expand Down Expand Up @@ -41,6 +42,11 @@ def scaled_stop_gradient(val: ScaledArray) -> ScaledArray:
return ScaledArray(lax.stop_gradient(val.data), lax.stop_gradient(val.scale))


@core.register_scaled_lax_op
def scaled_reshape(A: ScaledArray, new_sizes: Sequence[int], dimensions: Optional[Sequence[int]]) -> ScaledArray:
return ScaledArray(lax.reshape(A.data, new_sizes=new_sizes, dimensions=dimensions), A.scale)


@core.register_scaled_lax_op
def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray:
return ScaledArray(lax.broadcast_in_dim(A.data, shape=shape, broadcast_dimensions=broadcast_dimensions), A.scale)
Expand Down Expand Up @@ -81,8 +87,19 @@ def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:


@core.register_scaled_lax_op
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)
def scaled_neg(val: ScaledArray) -> ScaledArray:
return ScaledArray(-val.data, val.scale)


@core.register_scaled_lax_op
def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
return ScaledArray(lhs.data * rhs.data, lhs.scale * rhs.scale)


@core.register_scaled_lax_op
def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
# TODO: investigate different rule?
return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)


@core.register_scaled_lax_op
Expand All @@ -97,6 +114,10 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(output_data, output_scale)


# TODO: understand difference between `add` and `add_anys`
register_scaled_op(add_any_p, scaled_add)


@core.register_scaled_lax_op
def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
check_scalar_scales(A, B)
Expand Down Expand Up @@ -274,3 +295,13 @@ def scaled_cos(val: ScaledArray) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_sin(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.sin_p, [val])


@core.register_scaled_lax_op
def scaled_min(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.min_p, [lhs, rhs])


@core.register_scaled_lax_op
def scaled_max(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.max_p, [lhs, rhs])
41 changes: 41 additions & 0 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
scaled_broadcast_in_dim,
scaled_concatenate,
scaled_convert_element_type,
scaled_div,
scaled_dot_general,
scaled_exp,
scaled_is_finite,
scaled_log,
scaled_max,
scaled_min,
scaled_mul,
scaled_neg,
scaled_reshape,
scaled_select_n,
scaled_slice,
scaled_sub,
Expand All @@ -37,6 +42,13 @@ def test__scaled_broadcast_in_dim__proper_scaling(self):
npt.assert_array_equal(z.scale, x.scale)
npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1)))

def test__scaled_reshape__proper_scaling(self):
x = scaled_array(self.rs.rand(8), 2, dtype=np.float32)
z = scaled_reshape(x, new_sizes=(4, 2), dimensions=None)
assert isinstance(z, ScaledArray)
npt.assert_array_equal(z.scale, x.scale)
npt.assert_array_almost_equal(z.data, x.data.reshape((4, 2)))

def test__scaled_concatenate__proper_scaling(self):
x = scaled_array(self.rs.rand(2, 3), 0.5, dtype=np.float32)
y = scaled_array(self.rs.rand(5, 3), 2, dtype=np.float32)
Expand All @@ -59,6 +71,13 @@ def test__scaled_transpose__proper_scaling(self):
assert z.scale == x.scale
npt.assert_array_almost_equal(z.data, x.data.T)

def test__scaled_neg__proper_scaling(self):
x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32)
z = scaled_neg(x)
assert isinstance(z, ScaledArray)
assert z.scale == x.scale
npt.assert_array_almost_equal(z.data, -x.data)

def test__scaled_slice__proper_scaling(self):
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_slice(x, (1,), (4,), (2,))
Expand All @@ -74,6 +93,14 @@ def test__scaled_mul__proper_scaling(self):
assert z.scale == 6
npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y))

def test__scaled_div__proper_scaling(self):
x = scaled_array([-2.0, 2.0], 3.0, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2.0, dtype=np.float32)
z = scaled_div(x, y)
assert isinstance(z, ScaledArray)
assert z.scale == 1.5
npt.assert_array_almost_equal(z, np.asarray(x) / np.asarray(y))

def test__scaled_add__proper_scaling(self):
x = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
y = scaled_array([1.5, 4.5], 2.0, dtype=np.float32)
Expand Down Expand Up @@ -117,6 +144,20 @@ def test__scaled_log__proper_scaling(self):
npt.assert_almost_equal(out.scale, 1) # FIXME!
npt.assert_array_almost_equal(out, np.log(val))

def test__scaled_min__proper_scaling(self):
x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
z = scaled_min(x, y)
assert isinstance(z, ScaledArray)
npt.assert_array_almost_equal(z, np.minimum(x, y))

def test__scaled_max__proper_scaling(self):
x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
z = scaled_max(x, y)
assert isinstance(z, ScaledArray)
npt.assert_array_almost_equal(z, np.maximum(x, y))


class ScaledTranslationReducePrimitivesTests(chex.TestCase):
def setUp(self):
Expand Down

0 comments on commit 1ec9dd6

Please sign in to comment.