Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add epsilon parameter to dynamic rescale operations. #85

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions jax_scaled_arithmetics/ops/rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, pow2_round
from jax_scaled_arithmetics.core import ScaledArray, pow2_round, pow2_round_down
from jax_scaled_arithmetics.lax import get_data_scale, rebalance


Expand Down Expand Up @@ -44,12 +44,15 @@ def fn_on_grad_bwd(f, _, grad):

def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
"""Dynamic rescaling of a ScaledArray, using abs-max."""
# Similarly to ML norms => need some epsilon for training stability!
eps = pow2_round_down(np.float32(1e-4))

data, scale = get_data_scale(arr)
data_sq = jax.lax.abs(data)
axes = tuple(range(data.ndim))
# Get MAX norm + pow2 rounding.
norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
norm = pow2_round(norm.astype(scale.dtype))
norm = jax.lax.max(pow2_round(norm).astype(scale.dtype), eps.astype(scale.dtype))
# Rebalancing based on norm.
return rebalance(arr, norm)

Expand All @@ -59,12 +62,16 @@ def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray:

NOTE: by default, computing L1 norm in FP32.
"""
# Similarly to ML norms => need some epsilon for training stability!
norm_dtype = np.float32
eps = pow2_round_down(norm_dtype(1e-4))

data, scale = get_data_scale(arr)
data_sq = jax.lax.abs(data.astype(np.float32))
axes = tuple(range(data.ndim))
# Get L1 norm + pow2 rounding.
norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size
norm = pow2_round(norm.astype(scale.dtype))
norm = jax.lax.max(pow2_round(norm), eps).astype(scale.dtype)
# Rebalancing based on norm.
return rebalance(arr, norm)

Expand All @@ -74,12 +81,17 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:

NOTE: by default, computing L2 norm in FP32.
"""
# Similarly to ML norms => need some epsilon for training stability!
norm_dtype = np.float32
eps = pow2_round_down(norm_dtype(1e-4))

data, scale = get_data_scale(arr)
data_sq = jax.lax.integer_pow(data.astype(np.float32), 2)
data_sq = jax.lax.integer_pow(data.astype(norm_dtype), 2)
axes = tuple(range(data.ndim))
# Get L2 norm + pow2 rounding.
norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size)
norm = pow2_round(norm.astype(scale.dtype))
# Make sure we don't "underflow" too much on the norm.
norm = jax.lax.max(pow2_round(norm), eps).astype(scale.dtype)
# Rebalancing based on norm.
return rebalance(arr, norm)

Expand Down
13 changes: 13 additions & 0 deletions tests/ops/test_rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import chex
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.ops import dynamic_rescale_l1, dynamic_rescale_l2, dynamic_rescale_max
Expand Down Expand Up @@ -36,3 +37,15 @@ def test__dynamic_rescale_l2__proper_max_rescale_pow2_rounding(self):
assert arr_out.dtype == arr_in.dtype
npt.assert_array_equal(arr_out.scale, np.float16(16))
npt.assert_array_equal(arr_out, arr_in)

@parameterized.parameters(
{"dynamic_rescale_fn": dynamic_rescale_max},
{"dynamic_rescale_fn": dynamic_rescale_l1},
{"dynamic_rescale_fn": dynamic_rescale_l2},
)
def test__dynamic_rescale__epsilon_norm_value(self, dynamic_rescale_fn):
arr_in = scaled_array([0, 0], np.float32(1), dtype=np.float16)
arr_out = dynamic_rescale_fn(arr_in)
# Rough bounds on the epsilon value.
assert arr_out.scale > 0.0
assert arr_out.scale < 0.001