Skip to content

Commit

Permalink
Reorganize and clean MNIST + CIFAR10 examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jun 17, 2024
1 parent b3ed77c commit 0d52d7a
Show file tree
Hide file tree
Showing 17 changed files with 194 additions and 510 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ As presented in the code above, the model state is represented as a JAX PyTree o

A full collection of examples is available:
* [Scalify quickstart notebook](./examples/scalify-quickstart.ipynb): basics of `ScaledArray` and `scalify` transform;
* [MNIST FP16 training example](./experiments/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`;
* [MNIST FP8 training example](./experiments/mnist/mnist_classifier_from_scratch.py): easy FP8 support in `scalify`;
* [CIFAR10 training](./experiments/mnist/cifar_training.py): `scalify` CIFAR10 training, with Optax optimizer integration;
* [MNIST FP16 training example](./examples/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`;
* [MNIST FP8 training example](./examples/mnist/mnist_classifier_from_scratch_fp8.py): easy FP8 support in `scalify`;
* [MNIST Flax example](./examples/mnist/flax): `scalify` Flax training, with Optax optimizer integration;


## Installation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Graphcore Ltd 2024.

"""A basic MNIST example using Numpy and JAX.
"""A basic CIFAR10 example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
CIFAR10 training using MLP network + raw SGD optimizer.
"""


import time

import datasets
import dataset_cifar10
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -100,7 +99,7 @@ def accuracy(params, batch):
training_dtype = np.float16
scale_dtype = np.float32

train_images, train_labels, test_images, test_labels = datasets.cifar()
train_images, train_labels, test_images, test_labels = dataset_cifar10.cifar()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
Expand All @@ -118,7 +117,7 @@ def data_stream():
# Transform parameters to `ScaledArray` and proper dtype.
if use_scalify:
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@scalify
Expand All @@ -133,7 +132,7 @@ def update(params, batch):
# Scaled micro-batch + training dtype cast.
if use_scalify:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale))
batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Graphcore Ltd 2024.

"""A basic MNIST example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
"""A basic CIFAR10 example using Numpy and JAX.
"""


import time

import datasets
import dataset_cifar10
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -65,10 +64,6 @@ def predict(params, inputs):

final_w, final_b = params[-1]
logits = jnp.dot(activations, final_w) + final_b

# jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

# Dynamic rescaling of the gradient, as logits gradient not properly scaled.
logits = jsa.ops.dynamic_rescale_l2_grad(logits)
output = logits - logsumexp(logits, axis=1, keepdims=True)
Expand Down Expand Up @@ -102,7 +97,7 @@ def accuracy(params, batch):
batch_size = 128
scale_dtype = np.float32

train_images, train_labels, test_images, test_labels = datasets.cifar()
train_images, train_labels, test_images, test_labels = dataset_cifar10.cifar()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
Expand Down
154 changes: 154 additions & 0 deletions examples/cifar10/dataset_cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Graphcore Ltd 2024.

"""Datasets used in examples."""


import array
import gzip
import os
import pickle
import struct
import tarfile
import urllib.request
from os import path

import numpy as np

_DATA = "/tmp/jax_example_data/"


def _download(url, filename):
"""Download a url to a file in the JAX data temp directory."""
if not path.exists(_DATA):
os.makedirs(_DATA)
out_file = path.join(_DATA, filename)
if not path.isfile(out_file):
urllib.request.urlretrieve(url, out_file)
print(f"downloaded {url} to {_DATA}")


def _partial_flatten(x):
"""Flatten all but the first dimension of an ndarray."""
return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)


def _unzip(file):
file = tarfile.open(file)
file.extractall(_DATA)
file.close()
return


def _unpickle(file):
with open(file, "rb") as fo:
dict = pickle.load(fo, encoding="bytes")
return dict


def mnist_raw():
"""Download and parse the raw MNIST dataset."""
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

def parse_labels(filename):
with gzip.open(filename, "rb") as fh:
_ = struct.unpack(">II", fh.read(8))
return np.array(array.array("B", fh.read()), dtype=np.uint8)

def parse_images(filename):
with gzip.open(filename, "rb") as fh:
_, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols)

for filename in [
"train-images-idx3-ubyte.gz",
"train-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz",
]:
_download(base_url + filename, filename)

train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

return train_images, train_labels, test_images, test_labels


def mnist(permute_train=False):
"""Download, parse and process MNIST data to unit scale and one-hot labels."""
train_images, train_labels, test_images, test_labels = mnist_raw()

train_images = _partial_flatten(train_images) / np.float32(255.0)
test_images = _partial_flatten(test_images) / np.float32(255.0)
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

if permute_train:
perm = np.random.RandomState(0).permutation(train_images.shape[0])
train_images = train_images[perm]
train_labels = train_labels[perm]

return train_images, train_labels, test_images, test_labels


def cifar_raw():
"""Download, unzip and parse the raw cifar dataset."""

filename = "cifar-10-python.tar.gz"
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
_download(url, filename)
_unzip(path.join(_DATA, filename))

data_batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]
data = []
labels = []
for batch in data_batches:
tmp_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", batch))
data.append(tmp_dict[b"data"])
labels.append(tmp_dict[b"labels"])
train_images = np.concatenate(data)
train_labels = np.concatenate(labels)

test_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", "test_batch"))
test_images = test_dict[b"data"]
test_labels = np.array(test_dict[b"labels"])

return train_images, train_labels, test_images, test_labels


def cifar(permute_train=False):
"""Download, parse and process cifar data to unit scale and one-hot labels."""

train_images, train_labels, test_images, test_labels = cifar_raw()

train_images = train_images / np.float32(255.0)
test_images = test_images / np.float32(255.0)
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

if permute_train:
perm = np.random.RandomState(0).permutation(train_images.shape[0])
train_images = train_images[perm]
train_labels = train_labels[perm]

return train_images, train_labels, test_images, test_labels
2 changes: 2 additions & 0 deletions experiments/mnist/datasets.py → examples/mnist/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Graphcore Ltd 2024.


"""Datasets used in examples."""

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"""

import jax
import tensorflow as tf

# import tensorflow as tf
import train
from absl import app, flags, logging
from clu import platform
Expand All @@ -42,7 +43,7 @@ def main(argv):

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], "GPU")
# tf.config.experimental.set_visible_devices([], "GPU")

logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
logging.info("JAX local devices: %r", jax.local_devices())
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
import optax
import tensorflow_datasets as tfds
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax import linen as nn # type:ignore

# from flax.metrics import tensorboard
from flax.training import train_state

import jax_scalify as jsa
Expand Down Expand Up @@ -143,8 +144,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train
train_ds, test_ds = get_datasets()
rng = jax.random.key(0)

summary_writer = tensorboard.SummaryWriter(workdir)
summary_writer.hparams(dict(config))
# summary_writer = tensorboard.SummaryWriter(workdir)
# summary_writer.hparams(dict(config))

rng, init_rng = jax.random.split(rng)
init_rng = jax.random.PRNGKey(1)
Expand Down Expand Up @@ -173,10 +174,10 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train
)
)

summary_writer.scalar("train_loss", train_loss, epoch)
summary_writer.scalar("train_accuracy", train_accuracy, epoch)
summary_writer.scalar("test_loss", test_loss, epoch)
summary_writer.scalar("test_accuracy", test_accuracy, epoch)
# summary_writer.scalar("train_loss", train_loss, epoch)
# summary_writer.scalar("train_accuracy", train_accuracy, epoch)
# summary_writer.scalar("test_loss", test_loss, epoch)
# summary_writer.scalar("test_accuracy", test_accuracy, epoch)

summary_writer.flush()
# summary_writer.flush()
return state
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Graphcore Ltd 2024.

"""A basic MNIST example using Numpy and JAX.
Expand Down Expand Up @@ -78,9 +79,9 @@ def accuracy(params, batch):


if __name__ == "__main__":
layer_sizes = [784, 1024, 1024, 10]
param_scale = 1.0
step_size = 0.001
layer_sizes = [784, 512, 512, 10]
param_scale = 0.1
step_size = 0.1
num_epochs = 10
batch_size = 128

Expand Down Expand Up @@ -125,7 +126,7 @@ def update(params, batch):

epoch_time = time.time() - start_time

# Evaluation in float32, for consistency.
# Evaluation in normal/unscaled float32, for consistency.
raw_params = jsa.asarray(params, dtype=np.float32)
train_acc = accuracy(raw_params, (train_images, train_labels))
test_acc = accuracy(raw_params, (test_images, test_labels))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Graphcore Ltd 2024.

"""A basic MNIST example using Numpy and JAX.
Expand All @@ -34,6 +35,7 @@


def print_mean_std(name, v):
"""Debugging method/tool for JAX Scalify."""
data, scale = jsa.lax.get_data_scale(v)
# Always use np.float32, to avoid floating errors in descaling + stats.
data = jsa.asarray(data, dtype=np.float32)
Expand Down Expand Up @@ -105,9 +107,9 @@ def accuracy(params, batch):


if __name__ == "__main__":
layer_sizes = [784, 1024, 1024, 10]
param_scale = 1.0
step_size = 0.001
layer_sizes = [784, 512, 512, 10]
param_scale = 0.1
step_size = 0.1
num_epochs = 10
batch_size = 128

Expand Down
Loading

0 comments on commit 0d52d7a

Please sign in to comment.