Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNIST grid search experiments plots #92

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions experiments/mnist/mnist_classifier_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
The primary aim here is simplicity and minimal dependencies.
"""


import json
import os
import time

import datasets
Expand Down Expand Up @@ -76,14 +77,29 @@ def accuracy(params, batch):
return jnp.mean(predicted_class == target_class)


def update_experiments_json(filename, config, results):
print("Saving results in:", filename)
experiments = []
if os.path.exists(filename):
with open(filename) as f:
experiments = json.load(f)
experiments.append((config, results))
with open(filename, "w") as f:
json.dump(experiments, f, indent=4)


if __name__ == "__main__":
# Param scales: 0.5, 1, 2, 4, 8
# Step size: 0.0005, 0.001, 0.002, 0.004, 0.008, 0.016, 0.03

layer_sizes = [784, 1024, 1024, 10]
param_scale = 1.0
step_size = 0.001
num_epochs = 10
batch_size = 128

training_dtype = np.float16
use_autoscale = False
training_dtype = np.float32
scale_dtype = np.float32

train_images, train_labels, test_images, test_labels = datasets.mnist()
Expand All @@ -102,21 +118,27 @@ def data_stream():
batches = data_stream()
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
if use_autoscale:
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@jsa.autoscale
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

if use_autoscale:
update = jax.jit(jsa.autoscale(update))

# num_epochs = 1

for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
if use_autoscale:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
Expand All @@ -131,3 +153,15 @@ def update(params, batch):
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Training set accuracy {train_acc:0.5f}")
print(f"Test set accuracy {test_acc:0.5f}")

filename = os.path.join(os.path.dirname(__file__), "mnist_experiments.json")
config = (
param_scale,
step_size,
num_epochs,
use_autoscale,
str(np.dtype(training_dtype)),
str(np.dtype(scale_dtype)),
)
results = (float(train_acc), float(test_acc))
update_experiments_json(filename, config, results)
Loading