Adding FLAX MNIST training example. (#46)

Helps improve operation coverage in `autoscale` and find edge cases. The training can be tested using:

```shell
python ./experiments/mnist/flax_example/main.py \
    --workdir ./experiments/mnist/flax_example/results \
    --config ./experiments/mnist/flax_example/configs/default.py
```

NOTE: We still need to understand why some operations now require promotion to a scaled array. Any idea why this is required?
Showing 8 changed files with 343 additions and 4 deletions.

**.gitignore**
```
@@ -133,3 +133,6 @@ dmypy.json

# IDEs
.vscode

# ML tensorboard
*events.out*
```
**experiments/mnist/flax_example/README.md**
## MNIST classification

Trains a simple convolutional network on the MNIST dataset.

You can run this code and even modify it directly in Google Colab, no
installation required:

https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb

### Requirements

* TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary

### Example output

| Name    | Epochs | Walltime | Top-1 accuracy | Metrics     | Workdir                                    |
| :------ | -----: | :------- | :------------- | :---------- | :----------------------------------------- |
| default |     10 | 7.7m     | 99.17%         | [tfhub.dev] | [gs://flax_public/examples/mnist/default]  |

[tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default
[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

```
I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
```

### How to run

`python main.py --workdir=/tmp/mnist --config=configs/default.py`

#### Overriding hyperparameter configurations

The MNIST example allows specifying a hyperparameter configuration by means of
the `--config` flag. The configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags),
which allows overriding configuration fields. This can be done as follows:

```shell
python main.py \
  --workdir=/tmp/mnist --config=configs/default.py \
  --config.learning_rate=0.05 --config.num_epochs=5
```
Empty file.
**experiments/mnist/flax_example/configs/default.py**
```python
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    config.learning_rate = 0.1
    config.momentum = 0.9
    config.batch_size = 128
    config.num_epochs = 10
    return config


def metrics():
    return []
```
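Beyond the `--config.<field>` flag overrides shown in the README, a `ConfigDict` can also be edited in place when used programmatically. A minimal sketch, assuming a `configs.default` import path for the file above:

```python
# Hypothetical direct use of the config above, bypassing absl flag parsing.
# The import path "configs.default" is an assumption about the package layout.
from configs import default

config = default.get_config()
config.learning_rate = 0.05  # override a field in place
config.num_epochs = 5
print(config)
```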
**experiments/mnist/flax_example/main.py**
```python
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the MNIST example.

This file is intentionally kept short. The majority of logic is in libraries
that can be easily tested and imported in Colab.
"""

import jax
import tensorflow as tf
import train
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", None, "Directory to store model data.")
config_flags.DEFINE_config_file(
    "config",
    None,
    "File path to the training hyperparameter configuration.",
    lock_config=True,
)


def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0.)
    platform.work_unit().set_task_status(
        f"process_index: {jax.process_index()}, process_count: {jax.process_count()}"
    )
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir")

    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
    flags.mark_flags_as_required(["config", "workdir"])
    app.run(main)
```
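For quick experiments, the training loop can also be driven directly, skipping the absl flag machinery. A minimal sketch, assuming `train` and the config module are importable from the example directory:

```python
# Hypothetical direct invocation of the training loop defined in train.py.
import train
from configs import default  # assumed import path for configs/default.py

config = default.get_config()
config.num_epochs = 1  # shorten for a smoke test
state = train.train_and_evaluate(config, "/tmp/mnist")
```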
**experiments/mnist/flax_example/requirements.txt**
```
clu
flax
ml-collections
optax
tensorflow-datasets
```
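These dependencies can be installed in the usual way before running the example:

```shell
pip install -r requirements.txt
```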
**experiments/mnist/flax_example/train.py**
```python
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST example.

Library file which executes the training and evaluation loop for MNIST.
The data is loaded using tensorflow_datasets.
"""

# See issue #620.
# pytype: disable=wrong-keyword-args

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state

import jax_scaled_arithmetics as jsa


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)


@jax.jit
@jsa.autoscale
def apply_and_update_model(state, batch_images, batch_labels):
    # Jitting together forward + backward + update.
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    return state, loss, accuracy


def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["image"])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds["image"]))
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds["image"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        # Transform the batch to ScaledArray.
        batch_images = jsa.as_scaled_array(batch_images)
        # Apply & update stages in scaled mode.
        state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)

        epoch_loss.append(np.asarray(loss))
        epoch_accuracy.append(np.asarray(accuracy))

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder("mnist")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
    return train_ds, test_ds


def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"]
    tx = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState:
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.

    Returns:
      The train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = jax.random.key(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    init_rng = jax.random.PRNGKey(1)  # fixed initialization seed, overriding the split above

    state = create_train_state(init_rng, config)
    # Convert model & optimizer states to `ScaledArray`.
    state = jsa.as_scaled_array(state)

    logging.info("Start Flax MNIST training...")

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng)
        # NOTE: running evaluation on plain (unscaled) arrays.
        _, test_loss, test_accuracy = apply_model(jsa.asarray(state), test_ds["image"], test_ds["label"])

        logging.info(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
            " test_accuracy: %.2f"
            % (
                epoch,
                train_loss,
                train_accuracy * 100,
                test_loss,
                test_accuracy * 100,
            )
        )

        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

    summary_writer.flush()
    return state
```
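The `jsa` calls above follow a single pattern: convert inputs and state with `jsa.as_scaled_array`, trace the jitted step through `jsa.autoscale`, and convert back with `jsa.asarray` for plain-array evaluation. A minimal, self-contained sketch of that pattern on a toy function, assuming only the `jax_scaled_arithmetics` API used in this commit:

```python
# Sketch of the autoscale usage pattern from train.py, on a toy function.
import jax
import jax.numpy as jnp

import jax_scaled_arithmetics as jsa


@jax.jit
@jsa.autoscale  # trace with scale propagation, as in apply_and_update_model
def scaled_square_sum(x):
    return jnp.sum(x * x)


x = jnp.arange(8, dtype=jnp.float32)
xs = jsa.as_scaled_array(x)      # promote the input to ScaledArray
out = scaled_square_sum(xs)      # runs in scaled mode
print(jsa.asarray(out))          # convert back to a plain JAX array
```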