Improve numerical stability of scaled add/sub translations. (#58)
The current implementation can easily overflow, because the square of the scales is used for the L2 norm.

Use a min/max rescaling trick instead: underflowing the scale ratio to zero is preferable to overflowing the square of the scale.
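
For intuition, here is a minimal NumPy sketch of the trick (illustrative only, not part of this commit): the naive L2 norm sqrt(a^2 + b^2) overflows float16 as soon as one squared scale exceeds the dtype maximum (~65504), while the factored form max * sqrt(1 + (min/max)^2) stays finite.

import numpy as np

a, b = np.float16(2.0), np.float16(1024.0)

# Naive L2 norm: b * b = 1048576 overflows float16 (max ~65504) -> inf.
naive = np.sqrt(a * a + b * b)

# Factored form: the squared ratio may underflow towards zero, which is
# harmless; the result stays close to max(a, b).
lo, hi = min(a, b), max(a, b)
ratio = lo / hi
stable = hi * np.sqrt(np.float16(1.0) + ratio * ratio)

print(naive, stable)  # inf 1024.0

These are the same scale values (2.0 and 1024.0 in float16) that the new test below exercises.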
balancap authored Dec 19, 2023
1 parent 235806a commit ea122af
Showing 2 changed files with 28 additions and 18 deletions.
32 changes: 14 additions & 18 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -12,44 +12,40 @@
 from .scaled_ops_common import check_scalar_scales, promote_scale_types
 
 
 @core.register_scaled_lax_op
-def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArray:
+    """Scaled add/sub generic implementation."""
     # TODO: understand when promotion is really required?
     A, B = as_scaled_array((A, B))  # type:ignore
     check_scalar_scales(A, B)
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
-    # TODO: what happens to `sqrt` for non-floating scale?
-    output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale)
+    # More stable than direct L2 norm, to avoid scale overflow.
+    ABscale_max = lax.max(A.scale, B.scale)
+    ABscale_min = lax.min(A.scale, B.scale)
+    ABscale_ratio = ABscale_min / ABscale_max
+    output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio)
     # Output dtype => promotion of A and B dtypes.
     outdtype = jnp.promote_types(A.dtype, B.dtype)
     Arescale = (A.scale / output_scale).astype(outdtype)
     Brescale = (B.scale / output_scale).astype(outdtype)
     # check correct type output if mismatch between data and scale precision
-    output_data = Arescale * A.data + Brescale * B.data
+    output_data = binary_op(Arescale * A.data, Brescale * B.data)
     return ScaledArray(output_data, output_scale)
 
 
+@core.register_scaled_lax_op
+def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    return scaled_add_sub(A, B, lax.add)
+
+
 # TODO: understand difference between `add` and `add_anys`
 register_scaled_op(add_any_p, scaled_add)
 
 
 @core.register_scaled_lax_op
 def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
-    # TODO: understand when promotion is really required?
-    A, B = as_scaled_array((A, B))  # type:ignore
-    check_scalar_scales(A, B)
-    A, B = promote_scale_types(A, B)
-    assert np.issubdtype(A.scale.dtype, np.floating)
-    # TODO: what happens to `sqrt` for non-floating scale?
-    output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale)
-    # Output dtype => promotion of A and B dtypes.
-    outdtype = jnp.promote_types(A.dtype, B.dtype)
-    Arescale = (A.scale / output_scale).astype(outdtype)
-    Brescale = (B.scale / output_scale).astype(outdtype)
-    # check correct type output if mismatch between data and scale precision
-    output_data = Arescale * A.data - Brescale * B.data
-    return ScaledArray(output_data, output_scale)
+    return scaled_add_sub(A, B, lax.sub)
 
 
 @core.register_scaled_lax_op
14 changes: 14 additions & 0 deletions tests/lax/test_scaled_ops_l2.py
@@ -104,6 +104,20 @@ def test__scaled_addsub__proper_scaling(self, prim):
         assert z.dtype == x.dtype
         npt.assert_almost_equal(z.scale, np.sqrt(4.0 + 9.0))
 
+    @parameterized.parameters(
+        {"prim": lax.add_p},
+        {"prim": lax.sub_p},
+    )
+    def test__scaled_addsub__not_overflowing_scale(self, prim):
+        scaled_op, _ = find_registered_scaled_op(prim)
+        x = scaled_array([-1.0, 2.0], np.float16(2.0), dtype=np.float16)
+        y = scaled_array([1.5, 4.0], np.float16(1024.0), dtype=np.float16)
+        z = scaled_op(x, y)
+        print(z, x, y)
+        assert z.scale.dtype == np.float16
+        assert np.isfinite(z.scale)
+        npt.assert_array_almost_equal(z, prim.bind(np.asarray(x, np.float32), np.asarray(y, np.float32)), decimal=6)
+
     def test__scaled_mul__proper_scaling(self):
         x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
         y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
