Adding FLAX MNIST training example. (#46)
Helps improve operation coverage in `autoscale` and surface edge cases.
The training can be run with: `python ./experiments/mnist/flax_example/main.py --workdir ./experiments/mnist/flax_example/results --config ./experiments/mnist/flax_example/configs/default.py`

NOTE: We still need to understand why some operations now require promotion to a scaled array. Any idea why this is required?
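
For context, a minimal self-contained sketch (not part of this commit) of the usage pattern the example follows: the jitted train step is decorated with `jsa.autoscale` and the parameters/inputs are promoted with `jsa.as_scaled_array`, mirroring `apply_and_update_model` and `train_epoch` in `train.py` below. The toy linear model, shapes, and plain SGD update are illustrative placeholders, not part of the library.

```python
import jax
import jax.numpy as jnp

import jax_scaled_arithmetics as jsa


@jax.jit
@jsa.autoscale
def train_step(params, x, y):
    # Forward + backward + SGD update, all traced in scaled mode.
    def loss_fn(p):
        pred = x @ p["w"] + p["b"]
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


# Toy parameters and data (placeholders).
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))

# Promote parameters and inputs to `ScaledArray` before the scaled train step,
# as done for the train state and image batches in the example.
params = jsa.as_scaled_array(params)
x = jsa.as_scaled_array(x)

params, loss = train_step(params, x, y)
print(jsa.asarray(loss))  # convert back to a plain array for logging
```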
balancap authored Dec 5, 2023
1 parent 8fd8b90 commit 9d02472
Showing 8 changed files with 343 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -133,3 +133,6 @@ dmypy.json

# IDEs
.vscode

# ML tensorboard
*events.out*
43 changes: 43 additions & 0 deletions experiments/mnist/flax_example/README.md
@@ -0,0 +1,43 @@
## MNIST classification

Trains a simple convolutional network on the MNIST dataset.

You can run this code and even modify it directly in Google Colab, no
installation required:

https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb

### Requirements
* TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary

### Example output

| Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir |
| :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- |
| default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] |

[tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default
[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

```
I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
```

### How to run

`python main.py --workdir=/tmp/mnist --config=configs/default.py`

#### Overriding Hyperparameter configurations

The MNIST example allows specifying a hyperparameter configuration via the
`--config` flag. The configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags),
which also allows overriding individual configuration fields. This can be done as
follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.05 --config.num_epochs=5
```
Empty file.
32 changes: 32 additions & 0 deletions experiments/mnist/flax_example/configs/default.py
@@ -0,0 +1,32 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    config.learning_rate = 0.1
    config.momentum = 0.9
    config.batch_size = 128
    config.num_epochs = 10
    return config


def metrics():
    return []
62 changes: 62 additions & 0 deletions experiments/mnist/flax_example/main.py
@@ -0,0 +1,62 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for running the MNIST example.
This file is intentionally kept short. The majority of logic is in libraries
than can be easily tested and imported in Colab.
"""

import jax
import tensorflow as tf
import train
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", None, "Directory to store model data.")
config_flags.DEFINE_config_file(
    "config",
    None,
    "File path to the training hyperparameter configuration.",
    lock_config=True,
)


def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"process_index: {jax.process_index()}, " f"process_count: {jax.process_count()}"
    )
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir")

    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == "__main__":
    flags.mark_flags_as_required(["config", "workdir"])
    app.run(main)
5 changes: 5 additions & 0 deletions experiments/mnist/flax_example/requirements.txt
@@ -0,0 +1,5 @@
clu
flax
ml-collections
optax
tensorflow-datasets
182 changes: 182 additions & 0 deletions experiments/mnist/flax_example/train.py
@@ -0,0 +1,182 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST example.
Library file which executes the training and evaluation loop for MNIST.
The data is loaded using tensorflow_datasets.
"""

# See issue #620.
# pytype: disable=wrong-keyword-args

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state

import jax_scaled_arithmetics as jsa


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)


@jax.jit
@jsa.autoscale
def apply_and_update_model(state, batch_images, batch_labels):
    # Jitting together forward + backward + update.
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    return state, loss, accuracy


def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["image"])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds["image"]))
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds["image"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        # Transform batch to ScaledArray.
        batch_images = jsa.as_scaled_array(batch_images)
        # Apply & update stages in scaled mode.
        state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)

        epoch_loss.append(np.asarray(loss))
        epoch_accuracy.append(np.asarray(accuracy))

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder("mnist")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
    return train_ds, test_ds


def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"]
    tx = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState:
    """Execute model training and evaluation loop.
    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
    Returns:
      The train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = jax.random.key(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    init_rng = jax.random.PRNGKey(1)

    state = create_train_state(init_rng, config)
    # Convert model & optimizer states to `ScaledArray`.
    state = jsa.as_scaled_array(state)

    logging.info("Start Flax MNIST training...")

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng)
        # NOTE: running evaluation on plain (unscaled) arrays.
        _, test_loss, test_accuracy = apply_model(jsa.asarray(state), test_ds["image"], test_ds["label"])

        logging.info(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
            " test_accuracy: %.2f"
            % (
                epoch,
                train_loss,
                train_accuracy * 100,
                test_loss,
                test_accuracy * 100,
            )
        )

        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

    summary_writer.flush()
    return state
20 changes: 16 additions & 4 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -9,7 +9,15 @@
 from jax._src.ad_util import add_any_p
 
 from jax_scaled_arithmetics import core
-from jax_scaled_arithmetics.core import Array, DTypeLike, ScaledArray, Shape, as_scaled_array, register_scaled_op
+from jax_scaled_arithmetics.core import (
+    Array,
+    DTypeLike,
+    ScaledArray,
+    Shape,
+    as_scaled_array,
+    is_static_zero,
+    register_scaled_op,
+)
 
 from .base_scaling_primitives import scaled_set_scaling
 
@@ -100,8 +108,8 @@ def scaled_rev(val: ScaledArray, dimensions: Sequence[int]) -> ScaledArray:
 @core.register_scaled_lax_op
 def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> ScaledArray:
     # Only supporting constant zero padding for now.
-    assert float(padding_value) == 0.0
-    return ScaledArray(lax.pad(val.data, padding_value, padding_config), val.scale)
+    assert np.all(is_static_zero(padding_value))
+    return ScaledArray(lax.pad(val.data, np.array(0, val.dtype), padding_config), val.scale)
 
 
 @core.register_scaled_lax_op
@@ -128,19 +136,23 @@ def scaled_abs(val: ScaledArray) -> ScaledArray:
 
 @core.register_scaled_lax_op
 def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
-    # TODO: understand why/when this conversion kicks in?
+    # TODO: understand when promotion is really required?
     lhs, rhs = as_scaled_array((lhs, rhs))  # type:ignore
     return ScaledArray(lhs.data * rhs.data, lhs.scale * rhs.scale)
 
 
 @core.register_scaled_lax_op
 def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
+    # TODO: understand when promotion is really required?
+    lhs, rhs = as_scaled_array((lhs, rhs))  # type:ignore
     # TODO: investigate different rule?
     return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)
 
 
 @core.register_scaled_lax_op
 def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    # TODO: understand when promotion is really required?
+    A, B = as_scaled_array((A, B))  # type:ignore
     check_scalar_scales(A, B)
     A, B = promote_scale_types(A, B)
     assert np.issubdtype(A.scale.dtype, np.floating)
