From a870889a977f8993a876b4ed476411098202a43d Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 1 Dec 2023 11:30:42 +0000
Subject: [PATCH] Adding FLAX MNIST training example.

Helps improve operation coverage in `autoscale`.
---
 .gitignore                                 |   3 +
 experiments/mnist/flax/README.md           |  43 ++++++
 experiments/mnist/flax/configs/__init__.py |   0
 experiments/mnist/flax/configs/default.py  |  32 ++++
 experiments/mnist/flax/main.py             |  62 ++++++++
 experiments/mnist/flax/mnist_benchmark.py  |  82 +++++++++++
 experiments/mnist/flax/requirements.txt    |   5 +
 experiments/mnist/flax/train.py            | 161 +++++++++++++++++++++
 experiments/mnist/flax/train_test.py       |  71 +++++++++
 9 files changed, 459 insertions(+)
 create mode 100644 experiments/mnist/flax/README.md
 create mode 100644 experiments/mnist/flax/configs/__init__.py
 create mode 100644 experiments/mnist/flax/configs/default.py
 create mode 100644 experiments/mnist/flax/main.py
 create mode 100644 experiments/mnist/flax/mnist_benchmark.py
 create mode 100644 experiments/mnist/flax/requirements.txt
 create mode 100644 experiments/mnist/flax/train.py
 create mode 100644 experiments/mnist/flax/train_test.py

diff --git a/.gitignore b/.gitignore
index 07705cb..f8cc725 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,3 +133,6 @@ dmypy.json

 # IDEs
 .vscode
+
+# ML tensorboard
+*events.out*

diff --git a/experiments/mnist/flax/README.md b/experiments/mnist/flax/README.md
new file mode 100644
index 0000000..2a81692
--- /dev/null
+++ b/experiments/mnist/flax/README.md
@@ -0,0 +1,43 @@
## MNIST classification

Trains a simple convolutional network on the MNIST dataset.

You can run this code, and even modify it, directly in Google Colab with no
installation required:

https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb

### Requirements

* The TensorFlow dataset `mnist` is downloaded and prepared automatically if necessary.

### Example output

| Name    | Epochs | Walltime | Top-1 accuracy | Metrics     | Workdir                                    |
| :------ | -----: | :------- | :------------- | :---------- | :----------------------------------------- |
| default | 10     | 7.7m     | 99.17%         | [tfhub.dev] | [gs://flax_public/examples/mnist/default]  |

[tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default
[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

```
I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
```

### How to run

`python main.py --workdir=/tmp/mnist --config=configs/default.py`

#### Overriding hyperparameter configurations

The MNIST example allows specifying a hyperparameter configuration via the
`--config` flag, which is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
`config_flags` allows overriding individual configuration fields.
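Since the configuration is a plain `ml_collections.ConfigDict`, the same overrides can also be applied programmatically, e.g. from a notebook. A minimal sketch, assuming the example directory is on `sys.path`:

```python
from configs import default

config = default.get_config()
config.learning_rate = 0.05  # same effect as --config.learning_rate=0.05
config.num_epochs = 5        # same effect as --config.num_epochs=5
```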
From the command line, this can be done as follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.05 --config.num_epochs=5
```

diff --git a/experiments/mnist/flax/configs/__init__.py b/experiments/mnist/flax/configs/__init__.py
new file mode 100644
index 0000000..e69de29

diff --git a/experiments/mnist/flax/configs/default.py b/experiments/mnist/flax/configs/default.py
new file mode 100644
index 0000000..5603612
--- /dev/null
+++ b/experiments/mnist/flax/configs/default.py
@@ -0,0 +1,32 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default hyperparameter configuration."""

import ml_collections


def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    config.learning_rate = 0.1
    config.momentum = 0.9
    config.batch_size = 128
    config.num_epochs = 10
    return config


def metrics():
    return []

diff --git a/experiments/mnist/flax/main.py b/experiments/mnist/flax/main.py
new file mode 100644
index 0000000..29cd53e
--- /dev/null
+++ b/experiments/mnist/flax/main.py
@@ -0,0 +1,62 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the MNIST example.

This file is intentionally kept short. The majority of the logic is in
libraries that can be easily tested and imported in Colab.
"""

import jax
import tensorflow as tf
import train
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", None, "Directory to store model data.")
config_flags.DEFINE_config_file(
    "config",
    None,
    "File path to the training hyperparameter configuration.",
    lock_config=True,
)


def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0.)
    platform.work_unit().set_task_status(
        f"process_index: {jax.process_index()}, process_count: {jax.process_count()}"
    )
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir")

    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
    flags.mark_flags_as_required(["config", "workdir"])
    app.run(main)

diff --git a/experiments/mnist/flax/mnist_benchmark.py b/experiments/mnist/flax/mnist_benchmark.py
new file mode 100644
index 0000000..f21da05
--- /dev/null
+++ b/experiments/mnist/flax/mnist_benchmark.py
@@ -0,0 +1,82 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark for the MNIST example."""

import time

import jax
import main
import numpy as np
from absl import flags
from absl.testing import absltest
from absl.testing.flagsaver import flagsaver
from configs import default
from flax.testing import Benchmark

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()

FLAGS = flags.FLAGS


class MnistBenchmark(Benchmark):
    """Benchmarks for the MNIST Flax example."""

    @flagsaver
    def test_cpu(self):
        """Runs full MNIST training on CPU."""
        # Prepare and set flags defined in main.py.
        workdir = self.get_tmp_model_dir()
        config = default.get_config()

        FLAGS.workdir = workdir
        FLAGS.config = config

        start_time = time.time()
        main.main([])
        benchmark_time = time.time() - start_time

        summaries = self.read_summaries(workdir)

        # Summaries contain all the information necessary for the regression
        # metrics.
        wall_time, _, eval_accuracy = zip(*summaries["eval_accuracy"])
        wall_time = np.array(wall_time)
        sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1])
        end_eval_accuracy = eval_accuracy[-1]

        # Assertions are deferred until the test finishes, so the metrics are
        # always reported and benchmark success is determined based on *all*
        # assertions.
        self.assertBetween(end_eval_accuracy, 0.98, 1.0)

        # Use the reporting API to report single or multiple metrics/extras.
        self.report_wall_time(benchmark_time)
        self.report_metrics(
            {
                "sec_per_epoch": sec_per_epoch,
                "accuracy": end_eval_accuracy,
            }
        )
        self.report_extras(
            {
                "model_name": "MNIST",
                "description": "CPU test for MNIST.",
                "implementation": "linen",
            }
        )


if __name__ == "__main__":
    absltest.main()

diff --git a/experiments/mnist/flax/requirements.txt b/experiments/mnist/flax/requirements.txt
new file mode 100644
index 0000000..fce9900
--- /dev/null
+++ b/experiments/mnist/flax/requirements.txt
@@ -0,0 +1,5 @@
clu
flax
ml-collections
optax
tensorflow-datasets
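A note on the `sec_per_epoch` metric computed in `mnist_benchmark.py` above: `eval_accuracy` is summarized once per epoch, so the differences between consecutive summary wall times approximate the per-epoch duration. A small sketch of that arithmetic, with hypothetical wall times:

```python
import numpy as np

# Hypothetical wall times (in seconds) at which the per-epoch eval summaries
# were written; the values are illustrative only.
wall_time = np.array([12.0, 58.0, 103.0, 149.0])
sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1])  # mean of [46, 45, 46] ~= 45.7
```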
diff --git a/experiments/mnist/flax/train.py b/experiments/mnist/flax/train.py
new file mode 100644
index 0000000..f794223
--- /dev/null
+++ b/experiments/mnist/flax/train.py
@@ -0,0 +1,161 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST example.

Library file which executes the training and evaluation loop for MNIST.
The data is loaded using tensorflow_datasets.
"""

# See issue #620.
# pytype: disable=wrong-keyword-args

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)


def train_epoch(state, train_ds, batch_size, rng):
    """Trains for a single epoch."""
    train_ds_size = len(train_ds["image"])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds["image"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy
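As a quick illustration of the contract between `apply_model` and `update_model`, the sketch below runs a single optimization step on a dummy batch. It is not part of the patch; it only uses `CNN`, `apply_model`, and `update_model` as defined above:

```python
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# Build a fresh TrainState (equivalent to create_train_state further below).
cnn = CNN()
params = cnn.init(jax.random.key(0), jnp.ones([1, 28, 28, 1]))["params"]
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=optax.sgd(0.1, 0.9))

images = jnp.zeros((8, 28, 28, 1), jnp.float32)  # dummy MNIST-shaped batch
labels = jnp.zeros((8,), jnp.int32)
grads, loss, accuracy = apply_model(state, images, labels)  # grads mirrors state.params
state = update_model(state, grads)  # one SGD step
```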
def get_datasets():
    """Loads MNIST train and test datasets into memory."""
    ds_builder = tfds.builder("mnist")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
    return train_ds, test_ds


def create_train_state(rng, config):
    """Creates the initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"]
    tx = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState:
    """Executes the model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the TensorBoard summaries are written.

    Returns:
      The train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = jax.random.key(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng)
        _, test_loss, test_accuracy = apply_model(state, test_ds["image"], test_ds["label"])

        logging.info(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
            " test_accuracy: %.2f",
            epoch,
            train_loss,
            train_accuracy * 100,
            test_loss,
            test_accuracy * 100,
        )

        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

    summary_writer.flush()
    return state
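End to end, the library boils down to a single call. A minimal usage sketch (the `workdir` path below is illustrative; MNIST is downloaded via TFDS on first use, and TensorBoard summaries are written under `workdir`):

```python
import train
from configs import default

config = default.get_config()
config.num_epochs = 1  # keep the smoke run short
state = train.train_and_evaluate(config, workdir="/tmp/mnist_smoke")
```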
diff --git a/experiments/mnist/flax/train_test.py b/experiments/mnist/flax/train_test.py
new file mode 100644
index 0000000..ce11b51
--- /dev/null
+++ b/experiments/mnist/flax/train_test.py
@@ -0,0 +1,71 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Flax MNIST example (train.py)."""

import pathlib
import tempfile

import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import train
from absl.testing import absltest
from configs import default
from jax import numpy as jnp

CNN_PARAMS = 825_034


class TrainTest(absltest.TestCase):
    """Test cases for train.py."""

    def setUp(self):
        super().setUp()
        # Make sure tf does not allocate gpu memory.
        tf.config.experimental.set_visible_devices([], "GPU")

    def test_cnn(self):
        """Tests the CNN module used as the trainable model."""
        rng = jax.random.key(0)
        inputs = jnp.ones((1, 28, 28, 3), jnp.float32)
        output, variables = train.CNN().init_with_output(rng, inputs)

        self.assertEqual((1, 10), output.shape)
        self.assertEqual(
            CNN_PARAMS,
            sum(np.prod(arr.shape) for arr in jax.tree_util.tree_leaves(variables["params"])),
        )

    def test_train_and_evaluate(self):
        """Tests training and evaluation code by running a single step."""
        # Create a temporary directory where tensorboard metrics are written.
        workdir = tempfile.mkdtemp()

        # Go two directories up from this example (to the `experiments`
        # directory), where the mocked TFDS metadata is expected.
        flax_root_dir = pathlib.Path(__file__).parents[2]
        data_dir = str(flax_root_dir) + "/.tfds/metadata"

        # Define training configuration.
        config = default.get_config()
        config.num_epochs = 1
        config.batch_size = 8

        with tfds.testing.mock_data(num_examples=8, data_dir=data_dir):
            train.train_and_evaluate(config=config, workdir=workdir)


if __name__ == "__main__":
    absltest.main()
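For reference, the `CNN_PARAMS = 825_034` constant in `train_test.py` matches the 3-channel input used by `test_cnn`; the arithmetic checks out as follows:

```python
# Parameter count of train.CNN for a (1, 28, 28, 3) test input.
conv1 = 3 * 3 * 3 * 32 + 32      # 896: 3x3 kernel, 3 -> 32 channels, plus bias
conv2 = 3 * 3 * 32 * 64 + 64     # 18_496
# Two 2x2 average poolings: 28x28 -> 14x14 -> 7x7, flattened to 7 * 7 * 64 = 3136.
dense1 = 7 * 7 * 64 * 256 + 256  # 803_072
dense2 = 256 * 10 + 10           # 2_570
assert conv1 + conv2 + dense1 + dense2 == 825_034
```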