From 3ee4c08f3297fa1e5f9e0a696849a60fd0acfd9d Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 1 Dec 2023 11:30:42 +0000
Subject: [PATCH] Adding FLAX MNIST training example.

Helps improve operation coverage in `autoscale` and find edge cases.

The training can be tested using:

    python ./experiments/mnist/flax_example/main.py \
        --workdir ./experiments/mnist/flax_example/results \
        --config ./experiments/mnist/flax_example/configs/default.py

NOTE: We still need to understand why some operations now require
promotion to a scaled array. Any idea why this is required?
---
 .gitignore                                 |   3 +
 experiments/mnist/flax_example/README.md   |  43 +++++
 .../mnist/flax_example/configs/__init__.py |   0
 .../mnist/flax_example/configs/default.py  |  32 +++
 experiments/mnist/flax_example/main.py     |  62 ++++++
 .../mnist/flax_example/requirements.txt    |   5 +
 experiments/mnist/flax_example/train.py    | 182 ++++++++++++++++++
 jax_scaled_arithmetics/lax/scaled_ops.py   |  20 +-
 8 files changed, 343 insertions(+), 4 deletions(-)
 create mode 100644 experiments/mnist/flax_example/README.md
 create mode 100644 experiments/mnist/flax_example/configs/__init__.py
 create mode 100644 experiments/mnist/flax_example/configs/default.py
 create mode 100644 experiments/mnist/flax_example/main.py
 create mode 100644 experiments/mnist/flax_example/requirements.txt
 create mode 100644 experiments/mnist/flax_example/train.py

diff --git a/.gitignore b/.gitignore
index 07705cb..f8cc725 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,3 +133,6 @@ dmypy.json
 
 # IDEs
 .vscode
+
+# ML tensorboard
+*events.out*
diff --git a/experiments/mnist/flax_example/README.md b/experiments/mnist/flax_example/README.md
new file mode 100644
index 0000000..2a81692
--- /dev/null
+++ b/experiments/mnist/flax_example/README.md
@@ -0,0 +1,43 @@
+## MNIST classification
+
+Trains a simple convolutional network on the MNIST dataset.
+
+You can run this code and even modify it directly in Google Colab, no
+installation required:
+
+https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb
+
+### Requirements
+
+* The TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary.
+
+### Example output
+
+| Name    | Epochs | Walltime | Top-1 accuracy | Metrics     | Workdir                                    |
+| :------ | -----: | :------- | :------------- | :---------- | :----------------------------------------- |
+| default |     10 | 7.7m     | 99.17%         | [tfhub.dev] | [gs://flax_public/examples/mnist/default]  |
+
+[tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default
+[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default
+
+```
+I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
+I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
+```
+
+### How to run
+
+`python main.py --workdir=/tmp/mnist --config=configs/default.py`
+
+#### Overriding hyperparameter configurations
+
+The MNIST example allows specifying a hyperparameter configuration by means of
+the `--config` flag. The configuration flag is defined using
+[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
+`config_flags` allows overriding configuration fields.
+This can be done as follows:
+
+```shell
+python main.py \
+--workdir=/tmp/mnist --config=configs/default.py \
+--config.learning_rate=0.05 --config.num_epochs=5
+```
diff --git a/experiments/mnist/flax_example/configs/__init__.py b/experiments/mnist/flax_example/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/experiments/mnist/flax_example/configs/default.py b/experiments/mnist/flax_example/configs/default.py
new file mode 100644
index 0000000..5603612
--- /dev/null
+++ b/experiments/mnist/flax_example/configs/default.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Default hyperparameter configuration."""
+
+import ml_collections
+
+
+def get_config():
+    """Get the default hyperparameter configuration."""
+    config = ml_collections.ConfigDict()
+
+    config.learning_rate = 0.1
+    config.momentum = 0.9
+    config.batch_size = 128
+    config.num_epochs = 10
+    return config
+
+
+def metrics():
+    return []
diff --git a/experiments/mnist/flax_example/main.py b/experiments/mnist/flax_example/main.py
new file mode 100644
index 0000000..29cd53e
--- /dev/null
+++ b/experiments/mnist/flax_example/main.py
@@ -0,0 +1,62 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main file for running the MNIST example.
+
+This file is intentionally kept short. The majority of logic is in libraries
+that can be easily tested and imported in Colab.
+"""
+
+import jax
+import tensorflow as tf
+import train
+from absl import app, flags, logging
+from clu import platform
+from ml_collections import config_flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("workdir", None, "Directory to store model data.")
+config_flags.DEFINE_config_file(
+    "config",
+    None,
+    "File path to the training hyperparameter configuration.",
+    lock_config=True,
+)
+
+
+def main(argv):
+    if len(argv) > 1:
+        raise app.UsageError("Too many command-line arguments.")
+
+    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
+    # it unavailable to JAX.
+    tf.config.experimental.set_visible_devices([], "GPU")
+
+    logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
+    logging.info("JAX local devices: %r", jax.local_devices())
+
+    # Add a note so that we can tell which task is which JAX host.
+    # (Depending on the platform task 0 is not guaranteed to be host 0)
+    platform.work_unit().set_task_status(
+        f"process_index: {jax.process_index()}, " f"process_count: {jax.process_count()}"
+    )
+    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir")
+
+    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
+
+
+if __name__ == "__main__":
+    flags.mark_flags_as_required(["config", "workdir"])
+    app.run(main)
diff --git a/experiments/mnist/flax_example/requirements.txt b/experiments/mnist/flax_example/requirements.txt
new file mode 100644
index 0000000..fce9900
--- /dev/null
+++ b/experiments/mnist/flax_example/requirements.txt
@@ -0,0 +1,5 @@
+clu
+flax
+ml-collections
+optax
+tensorflow-datasets
diff --git a/experiments/mnist/flax_example/train.py b/experiments/mnist/flax_example/train.py
new file mode 100644
index 0000000..75a5dde
--- /dev/null
+++ b/experiments/mnist/flax_example/train.py
@@ -0,0 +1,182 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MNIST example.
+
+Library file which executes the training and evaluation loop for MNIST.
+The data is loaded using tensorflow_datasets.
+"""
+
+# See issue #620.
+# pytype: disable=wrong-keyword-args
+
+import jax
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+import optax
+import tensorflow_datasets as tfds
+from absl import logging
+from flax import linen as nn
+from flax.metrics import tensorboard
+from flax.training import train_state
+
+import jax_scaled_arithmetics as jsa
+
+
+class CNN(nn.Module):
+    """A simple CNN model."""
+
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
+        x = nn.relu(x)
+        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+        x = nn.relu(x)
+        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = x.reshape((x.shape[0], -1))  # flatten
+        x = nn.Dense(features=256)(x)
+        x = nn.relu(x)
+        x = nn.Dense(features=10)(x)
+        return x
+
+
+@jax.jit
+def apply_model(state, images, labels):
+    """Computes gradients, loss and accuracy for a single batch."""
+
+    def loss_fn(params):
+        logits = state.apply_fn({"params": params}, images)
+        one_hot = jax.nn.one_hot(labels, 10)
+        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
+        return loss, logits
+
+    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+    (loss, logits), grads = grad_fn(state.params)
+    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
+    return grads, loss, accuracy
+
+
+@jax.jit
+def update_model(state, grads):
+    return state.apply_gradients(grads=grads)
+
+
+@jax.jit
+@jsa.autoscale
+def apply_and_update_model(state, batch_images, batch_labels):
+    # Jitting together forward + backward + update.
+    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
+    state = update_model(state, grads)
+    return state, loss, accuracy
+
+
+def train_epoch(state, train_ds, batch_size, rng):
+    """Train for a single epoch."""
+    train_ds_size = len(train_ds["image"])
+    steps_per_epoch = train_ds_size // batch_size
+
+    perms = jax.random.permutation(rng, len(train_ds["image"]))
+    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
+    perms = perms.reshape((steps_per_epoch, batch_size))
+
+    epoch_loss = []
+    epoch_accuracy = []
+
+    for perm in perms:
+        batch_images = train_ds["image"][perm, ...]
+        batch_labels = train_ds["label"][perm, ...]
+        # Transform batch to ScaledArray.
+        batch_images = jsa.as_scaled_array(batch_images)
+        # Apply & update stages in scaled mode.
+        state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)
+
+        epoch_loss.append(np.asarray(loss))
+        epoch_accuracy.append(np.asarray(accuracy))
+
+    train_loss = np.mean(epoch_loss)
+    train_accuracy = np.mean(epoch_accuracy)
+    return state, train_loss, train_accuracy
+
+
+def get_datasets():
+    """Load MNIST train and test datasets into memory."""
+    ds_builder = tfds.builder("mnist")
+    ds_builder.download_and_prepare()
+    train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
+    test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
+    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
+    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
+    return train_ds, test_ds
+
+
+def create_train_state(rng, config):
+    """Creates initial `TrainState`."""
+    cnn = CNN()
+    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"]
+    tx = optax.sgd(config.learning_rate, config.momentum)
+    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
+
+
+def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState:
+    """Execute model training and evaluation loop.
+
+    Args:
+        config: Hyperparameter configuration for training and evaluation.
+        workdir: Directory where the tensorboard summaries are written to.
+
+    Returns:
+        The train state (which includes the `.params`).
+    """
+    train_ds, test_ds = get_datasets()
+    rng = jax.random.key(0)
+
+    summary_writer = tensorboard.SummaryWriter(workdir)
+    summary_writer.hparams(dict(config))
+
+    rng, init_rng = jax.random.split(rng)
+    init_rng = jax.random.PRNGKey(1)
+
+    state = create_train_state(init_rng, config)
+    # Convert model & optimizer states to `ScaledArray`.
+    state = jsa.as_scaled_array(state)
+
+    logging.info("Start Flax MNIST training...")
+
+    for epoch in range(1, config.num_epochs + 1):
+        rng, input_rng = jax.random.split(rng)
+        state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng)
+        # NOTE: running evaluation on plain (non-scaled) arrays.
+        _, test_loss, test_accuracy = apply_model(jsa.asarray(state), test_ds["image"], test_ds["label"])
+
+        logging.info(
+            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
+            " test_accuracy: %.2f"
+            % (
+                epoch,
+                train_loss,
+                train_accuracy * 100,
+                test_loss,
+                test_accuracy * 100,
+            )
+        )
+
+        summary_writer.scalar("train_loss", train_loss, epoch)
+        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
+        summary_writer.scalar("test_loss", test_loss, epoch)
+        summary_writer.scalar("test_accuracy", test_accuracy, epoch)
+
+    summary_writer.flush()
+    return state
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index 9f2f7fa..018b70d 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -9,7 +9,15 @@ from jax._src.ad_util import add_any_p
 
 from jax_scaled_arithmetics import core
-from jax_scaled_arithmetics.core import Array, DTypeLike, ScaledArray, Shape, as_scaled_array, register_scaled_op
+from jax_scaled_arithmetics.core import (
+    Array,
+    DTypeLike,
+    ScaledArray,
+    Shape,
+    as_scaled_array,
+    is_static_zero,
+    register_scaled_op,
+)
 
 from .base_scaling_primitives import scaled_set_scaling
 
@@ -100,8 +108,8 @@ def scaled_rev(val: ScaledArray, dimensions: Sequence[int]) -> ScaledArray:
 @core.register_scaled_lax_op
 def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> ScaledArray:
     # Only supporting constant zero padding for now.
-    assert float(padding_value) == 0.0
-    return ScaledArray(lax.pad(val.data, padding_value, padding_config), val.scale)
+    assert np.all(is_static_zero(padding_value))
+    return ScaledArray(lax.pad(val.data, np.array(0, val.dtype), padding_config), val.scale)
 
 
 @core.register_scaled_lax_op
@@ -128,19 +136,23 @@ def scaled_abs(val: ScaledArray) -> ScaledArray:
 
 @core.register_scaled_lax_op
 def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
-    # TODO: understand why/when this conversion kicks in?
+    # TODO: understand when promotion is really required?
     lhs, rhs = as_scaled_array((lhs, rhs))  # type:ignore
     return ScaledArray(lhs.data * rhs.data, lhs.scale * rhs.scale)
 
 
 @core.register_scaled_lax_op
 def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    # TODO: understand when promotion is really required?
+    lhs, rhs = as_scaled_array((lhs, rhs))  # type:ignore
     # TODO: investigate different rule?
     return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)
 
 
 @core.register_scaled_lax_op
 def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    # TODO: understand when promotion is really required?
+    A, B = as_scaled_array((A, B))  # type:ignore
     check_scalar_scales(A, B)
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
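A note on the promotion question raised in the commit message: one plausible explanation, still to be confirmed, is that `autoscale` tracing can hand a scaled op a mix of `ScaledArray` and plain array operands, for instance when one operand is a literal constant that was never promoted. That would be why `scaled_mul`, `scaled_div` and `scaled_add` now call `as_scaled_array` on their inputs. Below is a minimal sketch of the suspected pattern, assuming only the `jsa.autoscale` / `jsa.as_scaled_array` API already used in `train.py`; the function `fn` and the constant `0.5` are purely illustrative.

```python
import numpy as np

import jax_scaled_arithmetics as jsa


def fn(x):
    # Hypothesis to verify: the literal 0.5 enters the traced graph as a plain
    # (non-scaled) constant, so the registered `scaled_mul` translation may see
    # a ScaledArray operand mixed with a regular Array operand here.
    return x * 0.5


# Promote the input to a ScaledArray, then trace the function in scaled mode.
x = jsa.as_scaled_array(np.float32([1.0, 2.0, 4.0]))
out = jsa.autoscale(fn)(x)
print(out)
```

If the hypothesis holds, tracing `fn` reaches `scaled_mul` with the literal operand still unscaled, and the extra `as_scaled_array` call is what keeps the scaling rule well defined; if it does not, this small repro should at least make the actual trigger easier to pin down.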