Improving robustness of log and exp, with proper special values output. (#84)

Making sure that `exp` of `0` is `1` and `log` of `0` is `-inf`.
Using a custom `logsumexp` in the MNIST example until an additional scale propagation bug is solved.

NOTE: the additional robustness means MNIST training now converges when the initialization scale is > 1.
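To make the intended behaviour concrete, here is a minimal sketch (assuming `autoscale` can wrap `lax.exp`/`lax.log` directly, the same way it wraps a user function in the tests below):

import numpy as np
from jax import lax
from jax_scaled_arithmetics.core import autoscale, scaled_array

# Scale of 32768 * 16 = 524288 is not representable in float16 (max ~65504).
x = scaled_array(np.array([0.0, -2.0], np.float16), np.float32(32768 * 16))
print(autoscale(lax.exp)(x))  # expected ~ [1, 0]: exp(0) == 1, no NaN from the large scale
y = scaled_array(np.array([0.0, 1.0], np.float16), np.float32(32768 * 16))
print(autoscale(lax.log)(y))  # expected ~ [-inf, log(524288)]: log(0) == -inf, no NaN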
balancap authored Jan 15, 2024
1 parent ebfc951 commit 63ef3ba
Showing 4 changed files with 59 additions and 8 deletions.
22 changes: 18 additions & 4 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -25,11 +25,24 @@
import jax.numpy as jnp
import numpy as np
import numpy.random as npr
-from jax import grad, jit
-from jax.scipy.special import logsumexp
+from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa

# from jax.scipy.special import logsumexp


def logsumexp(a, axis=None, keepdims=False):
    dims = (axis,)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # FIXME: not proper scale propagation, introducing NaNs
    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax = lax.stop_gradient(amax)
    out = lax.sub(a, amax)
    out = lax.exp(out)
    out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax)
    return out


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
@@ -46,7 +59,8 @@ def predict(params, inputs):
    logits = jnp.dot(activations, final_w) + final_b
    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-    return logits - logsumexp(logits, axis=1, keepdims=True)
+    logits = logits - logsumexp(logits, axis=1, keepdims=True)
+    return logits


def loss(params, batch):
@@ -88,7 +102,7 @@ def data_stream():
    batches = data_stream()
    params = init_random_params(param_scale, layer_sizes)
    # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params, scale=scale_dtype(1))
+    params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

    @jit
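For reference, the custom `logsumexp` above follows the standard max-shift formulation, written with `lax` primitives so that each op is handled by the scaled-arithmetic interpreter. A quick standalone check of the identity it relies on (plain JAX, no scaling involved, purely illustrative):

import jax.numpy as jnp
from jax.scipy.special import logsumexp as jsp_logsumexp

# logsumexp(x) = max(x) + log(sum(exp(x - max(x)))): shifted exponentials stay <= 1, so no overflow.
x = jnp.array([[1.0, 2.0, 3.0], [10.0, 0.0, -5.0]])
amax = jnp.max(x, axis=1, keepdims=True)
stable = amax + jnp.log(jnp.sum(jnp.exp(x - amax), axis=1, keepdims=True))
print(jnp.allclose(stable, jsp_logsumexp(x, axis=1, keepdims=True)))  # True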
15 changes: 13 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -261,12 +261,23 @@ def scaled_op_default_translation(

@core.register_scaled_lax_op
def scaled_exp(val: ScaledArray) -> ScaledArray:
-    return scaled_op_default_translation(lax.exp_p, [val])
+    assert isinstance(val, ScaledArray)
+    # Estimate in FP32, to avoid NaN when "descaling" the array.
+    # Otherwise: issues for representing properly 0 and +-Inf.
+    arr = val.to_array(dtype=np.float32).astype(val.dtype)
+    scale = np.array(1, dtype=val.scale.dtype)
+    return ScaledArray(lax.exp(arr), scale)
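A small NumPy illustration (not part of the commit) of the failure mode the FP32 round-trip above avoids: a large scale cast to float16 overflows to inf, and 0 * inf is NaN, so descaling directly in the array dtype would turn exp(0) into NaN instead of 1.

import numpy as np

data = np.array([0.0, -1.0], np.float16)
scale = np.float32(32768 * 16)                                # 524288, above the float16 max of 65504
print(np.float16(scale))                                      # inf
print(data * np.float16(scale))                               # [nan, -inf]: 0 * inf -> nan
print((data.astype(np.float32) * scale).astype(np.float16))   # [0., -inf]: exp of this gives [1., 0.]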


@core.register_scaled_lax_op
def scaled_log(val: ScaledArray) -> ScaledArray:
-    return scaled_op_default_translation(lax.log_p, [val])
+    assert isinstance(val, ScaledArray)
+    # Log of data & scale components.
+    log_data = lax.log(val.data)
+    log_scale = lax.log(val.scale).astype(val.dtype)
+    data = log_data + log_scale
+    scale = np.array(1, dtype=val.scale.dtype)
+    return ScaledArray(data, scale)
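By contrast, `log` factorizes cleanly over the scale: log(scale * data) = log(data) + log(scale), which is what the translation above computes before resetting the output scale to 1. A NumPy sketch of the decomposition (illustrative only):

import numpy as np

data = np.array([0.0, 1.0, 4.0], np.float32)
scale = np.float32(32768 * 16)
direct = np.log(data * scale)              # log of the descaled values: [-inf, ~13.17, ~14.56]
decomposed = np.log(data) + np.log(scale)  # log of data and scale components, as above
print(np.allclose(direct, decomposed))     # True, and log(0) stays -inf rather than NaN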


@core.register_scaled_lax_op
19 changes: 18 additions & 1 deletion tests/lax/test_scaled_ops_l2.py
@@ -61,7 +61,24 @@ def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected
        assert out.dtype == val.dtype
        assert out.scale.dtype == val.scale.dtype
        npt.assert_almost_equal(out.scale, expected_scale)
-        npt.assert_array_almost_equal(out, expected_output)
+        # FIXME: higher precision for `log`?
+        npt.assert_array_almost_equal(out, expected_output, decimal=3)

    def test__scaled_exp__large_scale_zero_values(self):
        scaled_op, _ = find_registered_scaled_op(lax.exp_p)
        # Scaled array, with values < 0 and scale overflowing in float16.
        val = scaled_array(np.array([0, -1, -2, -32768], np.float16), np.float32(32768 * 16))
        out = scaled_op(val)
        # Zero value should not be a NaN!
        npt.assert_array_almost_equal(out, [1, 0, 0, 0], decimal=2)

    def test__scaled_log__zero_large_values_large_scale(self):
        scaled_op, _ = find_registered_scaled_op(lax.log_p)
        # 0 + large values => proper log values, without NaN/overflow.
        val = scaled_array(np.array([0, 1], np.float16), np.float32(32768 * 16))
        out = scaled_op(val)
        # No NaN value + not overflowing!
        npt.assert_array_almost_equal(out, lax.log(val.to_array(np.float32)), decimal=2)


class ScaledTranslationBinaryOpsTests(chex.TestCase):
11 changes: 10 additions & 1 deletion tests/lax/test_scipy_integration.py
@@ -3,16 +3,25 @@
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax import lax

from jax_scaled_arithmetics.core import autoscale, scaled_array


-class ScaledTranslationPrimitivesTests(chex.TestCase):
+class ScaledScipyHighLevelMethodsTests(chex.TestCase):
    def setUp(self):
        super().setUp()
        # Use random state for reproducibility!
        self.rs = np.random.RandomState(42)

    def test__lax_full_like__zero_scale(self):
        def fn(a):
            return lax.full_like(a, 0)

        a = scaled_array(np.random.rand(3, 5).astype(np.float32), np.float32(1))
        autoscale(fn)(a)
        # FIXME/TODO: what should be the expected result?

    @parameterized.parameters(
        {"dtype": np.float32},
        {"dtype": np.float16},
