diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py
index 7e78cc6..5c5c719 100644
--- a/experiments/mnist/mnist_classifier_from_scratch.py
+++ b/experiments/mnist/mnist_classifier_from_scratch.py
@@ -25,11 +25,43 @@
 import jax.numpy as jnp
 import numpy as np
 import numpy.random as npr
-from jax import grad, jit
-from jax.scipy.special import logsumexp
+from jax import grad, jit, lax
 
 import jax_scaled_arithmetics as jsa
 
+# from jax.scipy.special import logsumexp
+
+
+def logsumexp(a, axis=None, keepdims=False):
+    """Compute `log(sum(exp(a)))` over `axis`, in a numerically stable way.
+
+    Local replacement of `jax.scipy.special.logsumexp`, without the masking
+    of non-finite maximum values (see FIXME below).
+
+    Args:
+        a: Input array.
+        axis: Axis (int), axes (tuple/list) or None (all axes) to reduce.
+        keepdims: If True, reduced axes are kept with size one.
+    Returns:
+        Log-sum-exp of `a` over the reduced axes.
+    """
+    if axis is None:
+        dims = None
+    elif isinstance(axis, (tuple, list)):
+        dims = tuple(axis)
+    else:
+        dims = (axis,)
+    # Max always reduced with kept dims: `lax.sub` needs operands of the same rank.
+    amax = jnp.max(a, axis=dims, keepdims=True)
+    # FIXME: not proper scale propagation, introducing NaNs
+    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
+    amax = lax.stop_gradient(amax)
+    out = lax.sub(a, amax)
+    out = lax.exp(out)
+    out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=True)), amax)
+    # Reduced axes are only dropped at the very end, when requested.
+    return out if keepdims else jnp.squeeze(out, axis=dims)
+
 
 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
     return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
@@ -46,7 +78,8 @@ def predict(params, inputs):
     logits = jnp.dot(activations, final_w) + final_b
     # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
     logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-    return logits - logsumexp(logits, axis=1, keepdims=True)
+    logits = logits - logsumexp(logits, axis=1, keepdims=True)
+    return logits
 
 
 def loss(params, batch):
@@ -88,7 +121,7 @@ def data_stream():
     batches = data_stream()
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params, scale=scale_dtype(1))
+    params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
     params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
 
     @jit
diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py
index 91f0f5f..d9e6345 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops_common.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -261,12 +261,23 @@ def scaled_op_default_translation(
 
 @core.register_scaled_lax_op
 def scaled_exp(val: ScaledArray) -> ScaledArray:
-    return scaled_op_default_translation(lax.exp_p, [val])
+    assert isinstance(val, ScaledArray)
+    # Estimate in FP32, to avoid NaN when "descaling" the array.
+    # Otherwise: issues for representing properly 0 and +-Inf.
+    arr = val.to_array(dtype=np.float32).astype(val.dtype)
+    scale = np.array(1, dtype=val.scale.dtype)
+    return ScaledArray(lax.exp(arr), scale)
 
 
 @core.register_scaled_lax_op
 def scaled_log(val: ScaledArray) -> ScaledArray:
-    return scaled_op_default_translation(lax.log_p, [val])
+    assert isinstance(val, ScaledArray)
+    # Log of data & scale components.
+    log_data = lax.log(val.data)
+    log_scale = lax.log(val.scale).astype(val.dtype)
+    data = log_data + log_scale
+    scale = np.array(1, dtype=val.scale.dtype)
+    return ScaledArray(data, scale)
 
 
 @core.register_scaled_lax_op
diff --git a/tests/lax/test_scaled_ops_l2.py b/tests/lax/test_scaled_ops_l2.py
index 3d9bf90..85c2600 100644
--- a/tests/lax/test_scaled_ops_l2.py
+++ b/tests/lax/test_scaled_ops_l2.py
@@ -61,7 +61,24 @@ def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected
         assert out.dtype == val.dtype
         assert out.scale.dtype == val.scale.dtype
         npt.assert_almost_equal(out.scale, expected_scale)
-        npt.assert_array_almost_equal(out, expected_output)
+        # FIXME: higher precision for `log`?
+        npt.assert_array_almost_equal(out, expected_output, decimal=3)
+
+    def test__scaled_exp__large_scale_zero_values(self):
+        scaled_op, _ = find_registered_scaled_op(lax.exp_p)
+        # Scaled array, with values < 0 and scale overflowing in float16.
+        val = scaled_array(np.array([0, -1, -2, -32768], np.float16), np.float32(32768 * 16))
+        out = scaled_op(val)
+        # Zero value should not be a NaN!
+        npt.assert_array_almost_equal(out, [1, 0, 0, 0], decimal=2)
+
+    def test__scaled_log__zero_large_values_large_scale(self):
+        scaled_op, _ = find_registered_scaled_op(lax.log_p)
+        # 0 + large values => proper log values, without NaN/overflow.
+        val = scaled_array(np.array([0, 1], np.float16), np.float32(32768 * 16))
+        out = scaled_op(val)
+        # No NaN value + not overflowing!
+        npt.assert_array_almost_equal(out, lax.log(val.to_array(np.float32)), decimal=2)
 
 
 class ScaledTranslationBinaryOpsTests(chex.TestCase):
diff --git a/tests/lax/test_scipy_integration.py b/tests/lax/test_scipy_integration.py
index 1269af0..b0d97ae 100644
--- a/tests/lax/test_scipy_integration.py
+++ b/tests/lax/test_scipy_integration.py
@@ -3,16 +3,25 @@
 import numpy as np
 import numpy.testing as npt
 from absl.testing import parameterized
+from jax import lax
 
 from jax_scaled_arithmetics.core import autoscale, scaled_array
 
 
-class ScaledTranslationPrimitivesTests(chex.TestCase):
+class ScaledScipyHighLevelMethodsTests(chex.TestCase):
     def setUp(self):
         super().setUp()
         # Use random state for reproducibility!
         self.rs = np.random.RandomState(42)
 
+    def test__lax_full_like__zero_scale(self):
+        def fn(a):
+            return lax.full_like(a, 0)
+
+        a = scaled_array(np.random.rand(3, 5).astype(np.float32), np.float32(1))
+        autoscale(fn)(a)
+        # FIXME/TODO: what should be the expected result?
+
     @parameterized.parameters(
         {"dtype": np.float32},
         {"dtype": np.float16},