Skip to content

Commit

Permalink
Implement reduce_sum/prod/max/min scaled translation rules.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Nov 21, 2023
1 parent 642a58e commit 5aefb4a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
38 changes: 38 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,41 @@ def scaled_dot_general(
/ contracting_rescale
)
return ScaledArray(output_data, output_scale)


@core.register_scaled_lax_op
def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
axes_rescale = np.sqrt(np.prod(axes_size))
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale
outscale = val.scale * axes_rescale
return ScaledArray(data, outscale)


@core.register_scaled_lax_op
def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)


@core.register_scaled_lax_op
def scaled_reduce_max(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_max_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_min_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)
40 changes: 40 additions & 0 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
scaled_convert_element_type,
scaled_dot_general,
scaled_mul,
scaled_reduce_max,
scaled_reduce_min,
scaled_reduce_prod,
scaled_reduce_sum,
scaled_slice,
scaled_sub,
scaled_transpose,
Expand Down Expand Up @@ -93,3 +97,39 @@ def test__scaled_dot_general__proper_scaling(self):
assert out.dtype == lhs.dtype
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs))

def test__scaled_reduce_sum__proper_scaling(self):
val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
out = scaled_reduce_sum(val, axes=(0,))
assert isinstance(out, ScaledArray)
assert out.shape == ()
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, val.scale * np.sqrt(5))
npt.assert_array_almost_equal(out, np.sum(val))

def test__scaled_reduce_prod__proper_scaling(self):
val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
out = scaled_reduce_prod(val, axes=(0,))
assert isinstance(out, ScaledArray)
assert out.shape == ()
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, 2**5)
npt.assert_array_almost_equal(out, np.prod(val))

def test__scaled_reduce_min__proper_scaling(self):
val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
out = scaled_reduce_min(val, axes=(0,))
assert isinstance(out, ScaledArray)
assert out.shape == ()
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, 2)
npt.assert_array_almost_equal(out, np.min(val))

def test__scaled_reduce_max__proper_scaling(self):
val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
out = scaled_reduce_max(val, axes=(0,))
assert isinstance(out, ScaledArray)
assert out.shape == ()
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, 2)
npt.assert_array_almost_equal(out, np.max(val))

0 comments on commit 5aefb4a

Please sign in to comment.