diff --git a/experiments/mnist/optax_cifar_training.py b/experiments/mnist/optax_cifar_training.py
new file mode 100644
index 0000000..dab0e59
--- /dev/null
+++ b/experiments/mnist/optax_cifar_training.py
@@ -0,0 +1,159 @@
+# Copyright 2018 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A basic CIFAR-10 training example using NumPy, Optax and JAX.
+
+The primary aim here is simplicity and minimal dependencies.
+"""
+
+
+import time
+
+import datasets
+import jax
+import jax.numpy as jnp
+import numpy as np
+import numpy.random as npr
+import optax
+from jax import grad, jit, lax
+
+import jax_scaled_arithmetics as jsa
+
+
+def logsumexp(a, axis=None, keepdims=False):
+    dims = (axis,)
+    amax = jnp.max(a, axis=dims, keepdims=keepdims)
+    # FIXME: not proper scale propagation, introducing NaNs
+    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
+    amax = lax.stop_gradient(amax)
+    out = lax.sub(a, amax)
+    out = lax.exp(out)
+    out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax)
+    return out
+
+
+def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
+    return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
+
+
+def print_mean_std(name, v):
+    data, scale = jsa.lax.get_data_scale(v)
+    # Always use np.float32, to avoid floating-point errors in descaling + stats.
+    v = jsa.asarray(data, dtype=np.float32)
+    m, s = np.mean(v), np.std(v)
+    # print(data)
+    print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
+
+
+def predict(params, inputs):
+    activations = inputs
+    for w, b in params[:-1]:
+        # Matmul + relu
+        outputs = jnp.dot(activations, w) + b
+        activations = jnp.maximum(outputs, 0)
+
+    final_w, final_b = params[-1]
+    logits = jnp.dot(activations, final_w) + final_b
+
+    # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
+    # Dynamic rescaling of the gradient, as the logits gradient is not properly scaled.
+    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
+    output = logits - logsumexp(logits, axis=1, keepdims=True)
+
+    return output
+
+
+def loss(params, batch):
+    inputs, targets = batch
+    preds = predict(params, inputs)
+    return -jnp.mean(jnp.sum(preds * targets, axis=1))
+
+
+def accuracy(params, batch):
+    inputs, targets = batch
+    target_class = jnp.argmax(targets, axis=1)
+    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
+    return jnp.mean(predicted_class == target_class)
+
+
+if __name__ == "__main__":
+    width = 256
+    lr = 1e-3
+    use_autoscale = False
+    training_dtype = np.float32
+    autoscale = jsa.autoscale if use_autoscale else lambda f: f
+
+    layer_sizes = [3072, width, width, 10]
+    param_scale = 1.0
+    num_epochs = 10
+    batch_size = 128
+    scale_dtype = np.float32
+
+    train_images, train_labels, test_images, test_labels = datasets.cifar()
+    num_train = train_images.shape[0]
+    num_complete_batches, leftover = divmod(num_train, batch_size)
+    num_batches = num_complete_batches + bool(leftover)
+    # num_batches = 2
+
+    def data_stream():
+        rng = npr.RandomState(0)
+        while True:
+            perm = rng.permutation(num_train)
+            for i in range(num_batches):
+                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
+                yield train_images[batch_idx], train_labels[batch_idx]
+
+    batches = data_stream()
+    params = init_random_params(param_scale, layer_sizes)
+    params = jax.tree_map(lambda v: v.astype(training_dtype), params)
+    optimizer = optax.adam(learning_rate=lr, eps=1e-5)
+    opt_state = optimizer.init(params)
+
+    # Transform parameters to `ScaledArray` and proper dtype.
+    if use_autoscale:
+        params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
+
+    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+
+    @jit
+    @autoscale
+    def update(params, batch, opt_state):
+        grads = grad(loss)(params, batch)
+        updates, opt_state = optimizer.update(grads, opt_state)
+        params = optax.apply_updates(params, updates)
+        return params, opt_state
+
+    for epoch in range(num_epochs):
+        start_time = time.time()
+        for _ in range(num_batches):
+            batch = next(batches)
+            # Scaled micro-batch + training dtype cast.
+            if use_autoscale:
+                batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale))
+            batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+
+            with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
+                params, opt_state = update(params, batch, opt_state)
+
+        epoch_time = time.time() - start_time
+
+        # Evaluation in float32, for consistency.
+        raw_params = jsa.asarray(params, dtype=np.float32)
+        train_acc = accuracy(raw_params, (train_images, train_labels))
+        test_acc = accuracy(raw_params, (test_images, test_labels))
+        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
+        print(f"Training set accuracy {train_acc:0.5f}")
+        print(f"Test set accuracy {test_acc:0.5f}")
diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py
index 49799fb..b58ffd9 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops_common.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -316,3 +316,17 @@ def scaled_cos(val: ScaledArray) -> ScaledArray:
 @core.register_scaled_lax_op
 def scaled_sin(val: ScaledArray) -> ScaledArray:
     return scaled_op_default_translation(lax.sin_p, [val])
+
+
+@core.register_scaled_lax_op
+def scaled_integer_pow(A: ScaledArray, y: int) -> ScaledArray:
+    output_scale = lax.integer_pow(A.scale, y)
+    output_data = lax.integer_pow(A.data, y)
+    return ScaledArray(output_data, output_scale)
+
+
+@core.register_scaled_lax_op
+def scaled_sqrt(val: ScaledArray) -> ScaledArray:
+    output_scale = lax.sqrt(val.scale)
+    output_data = lax.sqrt(val.data)
+    return ScaledArray(output_data, output_scale)
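
Note: as a minimal sketch, the new `scaled_integer_pow` and `scaled_sqrt` ops propagate the scale directly through the same primitive, assuming a `ScaledArray(data, scale)` decodes to `data * scale`. The identities this relies on can be checked with plain JAX:

import jax.numpy as jnp
from jax import lax

# A ScaledArray(data, scale) stands for the tensor `data * scale`.
data = jnp.array([1.5, 2.0, 3.0], dtype=jnp.float32)
scale = jnp.array(4.0, dtype=jnp.float32)

# integer_pow: (data * scale)**y == (data**y) * (scale**y), so the scale propagates as scale**y.
y = 3
assert jnp.allclose(lax.integer_pow(data * scale, y), lax.integer_pow(data, y) * lax.integer_pow(scale, y))

# sqrt: sqrt(data * scale) == sqrt(data) * sqrt(scale), so the scale propagates as sqrt(scale).
assert jnp.allclose(lax.sqrt(data * scale), lax.sqrt(data) * lax.sqrt(scale))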