Adding FLAX MNIST training example.
Helps improve operation coverage in `autoscale`.
balancap committed Dec 4, 2023
1 parent 5ae6190 commit a870889
Showing 9 changed files with 459 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -133,3 +133,6 @@ dmypy.json

# IDEs
.vscode

# ML tensorboard
*events.out*
43 changes: 43 additions & 0 deletions experiments/mnist/flax/README.md
@@ -0,0 +1,43 @@
## MNIST classification

Trains a simple convolutional network on the MNIST dataset.

You can run this code, and even modify it, directly in Google Colab; no
installation is required:

https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb

### Requirements
* The TensorFlow dataset `mnist` will be downloaded and prepared automatically if necessary; a manual pre-fetch sketch is shown below.
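
If you prefer to fetch the dataset ahead of time, a minimal sketch, mirroring
what `get_datasets` in `train.py` does:

```python
import tensorflow_datasets as tfds

# Download and prepare MNIST into the local TFDS cache.
tfds.builder("mnist").download_and_prepare()
```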

### Example output

| Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir |
| :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- |
| default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] |

[tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default
[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

```
I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
```

### How to run

`python main.py --workdir=/tmp/mnist --config=configs/default.py`

#### Overriding Hyperparameter configurations

The MNIST example allows specifying a hyperparameter configuration with the
`--config` flag. The configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags),
which also allows overriding individual configuration fields. This can be done
as follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.05 --config.num_epochs=5
```
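
Training writes TensorBoard event files (`*events.out*`) to `--workdir`;
assuming the default workdir from above, the metrics can be inspected with:

```shell
tensorboard --logdir=/tmp/mnist
```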
Empty file.
32 changes: 32 additions & 0 deletions experiments/mnist/flax/configs/default.py
@@ -0,0 +1,32 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

config.learning_rate = 0.1
config.momentum = 0.9
config.batch_size = 128
config.num_epochs = 10
return config


def metrics():
return []
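

# A minimal usage sketch, assuming this module is importable as
# `configs.default` (as `mnist_benchmark.py` does):
#
#   from configs import default
#
#   config = default.get_config()
#   config.learning_rate = 0.05  # override individual fields before training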
62 changes: 62 additions & 0 deletions experiments/mnist/flax/main.py
@@ -0,0 +1,62 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the MNIST example.
This file is intentionally kept short. The majority of logic is in libraries
than can be easily tested and imported in Colab.
"""

import jax
import tensorflow as tf
import train
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", None, "Directory to store model data.")
config_flags.DEFINE_config_file(
"config",
None,
"File path to the training hyperparameter configuration.",
lock_config=True,
)


def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], "GPU")

logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
logging.info("JAX local devices: %r", jax.local_devices())

# Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0.)
platform.work_unit().set_task_status(
f"process_index: {jax.process_index()}, " f"process_count: {jax.process_count()}"
)
platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir")

train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
flags.mark_flags_as_required(["config", "workdir"])
app.run(main)
82 changes: 82 additions & 0 deletions experiments/mnist/flax/mnist_benchmark.py
@@ -0,0 +1,82 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark for the MNIST example."""
import time

import jax
import main
import numpy as np
from absl import flags
from absl.testing import absltest
from absl.testing.flagsaver import flagsaver
from configs import default
from flax.testing import Benchmark

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()

FLAGS = flags.FLAGS


class MnistBenchmark(Benchmark):
"""Benchmarks for the MNIST Flax example."""

@flagsaver
def test_cpu(self):
"""Run full training for MNIST CPU training."""
# Prepare and set flags defined in main.py.
workdir = self.get_tmp_model_dir()
config = default.get_config()

FLAGS.workdir = workdir
FLAGS.config = config

start_time = time.time()
main.main([])
benchmark_time = time.time() - start_time

summaries = self.read_summaries(workdir)

# Summaries contain all the information necessary for the regression
# metrics.
wall_time, _, eval_accuracy = zip(*summaries["eval_accuracy"])
wall_time = np.array(wall_time)
sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1])
end_eval_accuracy = eval_accuracy[-1]

# Assertions are deferred until the test finishes, so the metrics are
# always reported and benchmark success is determined based on *all*
# assertions.
self.assertBetween(end_eval_accuracy, 0.98, 1.0)

# Use the reporting API to report single or multiple metrics/extras.
self.report_wall_time(benchmark_time)
self.report_metrics(
{
"sec_per_epoch": sec_per_epoch,
"accuracy": end_eval_accuracy,
}
)
self.report_extras(
{
"model_name": "MNIST",
"description": "CPU test for MNIST.",
"implementation": "linen",
}
)


if __name__ == "__main__":
absltest.main()
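

# A minimal run sketch: the file ends in `absltest.main()`, so the benchmark
# can be launched directly, assuming the packages in `requirements.txt` are
# installed:
#
#   python mnist_benchmark.py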
5 changes: 5 additions & 0 deletions experiments/mnist/flax/requirements.txt
@@ -0,0 +1,5 @@
clu
flax
ml-collections
optax
tensorflow-datasets
161 changes: 161 additions & 0 deletions experiments/mnist/flax/train.py
@@ -0,0 +1,161 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST example.
Library file which executes the training and evaluation loop for MNIST.
The data is loaded using tensorflow_datasets.
"""

# See issue #620.
# pytype: disable=wrong-keyword-args

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state


class CNN(nn.Module):
"""A simple CNN model."""

@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x


@jax.jit
def apply_model(state, images, labels):
"""Computes gradients, loss and accuracy for a single batch."""

def loss_fn(params):
logits = state.apply_fn({"params": params}, images)
one_hot = jax.nn.one_hot(labels, 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
return state.apply_gradients(grads=grads)


def train_epoch(state, train_ds, batch_size, rng):
"""Train for a single epoch."""
train_ds_size = len(train_ds["image"])
steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))

epoch_loss = []
epoch_accuracy = []

for perm in perms:
batch_images = train_ds["image"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
state = update_model(state, grads)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
train_loss = np.mean(epoch_loss)
train_accuracy = np.mean(epoch_accuracy)
return state, train_loss, train_accuracy


def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
return train_ds, test_ds


def create_train_state(rng, config):
"""Creates initial `TrainState`."""
cnn = CNN()
params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"]
tx = optax.sgd(config.learning_rate, config.momentum)
return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState:
"""Execute model training and evaluation loop.
Args:
config: Hyperparameter configuration for training and evaluation.
workdir: Directory where the tensorboard summaries are written to.
Returns:
The train state (which includes the `.params`).
"""
train_ds, test_ds = get_datasets()
rng = jax.random.key(0)

summary_writer = tensorboard.SummaryWriter(workdir)
summary_writer.hparams(dict(config))

    # Fixed PRNG key for parameter initialization, so runs are reproducible.
    init_rng = jax.random.PRNGKey(1)
state = create_train_state(init_rng, config)

for epoch in range(1, config.num_epochs + 1):
rng, input_rng = jax.random.split(rng)
state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng)
_, test_loss, test_accuracy = apply_model(state, test_ds["image"], test_ds["label"])

        logging.info(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
            " test_accuracy: %.2f",
            epoch,
            train_loss,
            train_accuracy * 100,
            test_loss,
            test_accuracy * 100,
        )

summary_writer.scalar("train_loss", train_loss, epoch)
summary_writer.scalar("train_accuracy", train_accuracy, epoch)
summary_writer.scalar("test_loss", test_loss, epoch)
summary_writer.scalar("test_accuracy", test_accuracy, epoch)

summary_writer.flush()
return state
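

# A minimal usage sketch, assuming the default config module is importable:
#
#   import train
#   from configs import default
#
#   state = train.train_and_evaluate(default.get_config(), workdir="/tmp/mnist")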
