Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Dec 4, 2023
1 parent a870889 commit de39aca
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions experiments/mnist/flax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from flax.metrics import tensorboard
from flax.training import train_state

import jax_scaled_arithmetics as jsa


class CNN(nn.Module):
"""A simple CNN model."""
Expand Down Expand Up @@ -72,6 +74,15 @@ def update_model(state, grads):
return state.apply_gradients(grads=grads)


@jax.jit
@jsa.autoscale
def apply_and_update_model(state, batch_images, batch_labels):
# Jitting together forward + backward + update.
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
state = update_model(state, grads)
return state, loss, accuracy


def train_epoch(state, train_ds, batch_size, rng):
"""Train for a single epoch."""
train_ds_size = len(train_ds["image"])
Expand All @@ -87,8 +98,7 @@ def train_epoch(state, train_ds, batch_size, rng):
for perm in perms:
batch_images = train_ds["image"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
state = update_model(state, grads)
state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
train_loss = np.mean(epoch_loss)
Expand Down Expand Up @@ -133,7 +143,13 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train

rng, init_rng = jax.random.split(rng)
init_rng = jax.random.PRNGKey(1)

state = create_train_state(init_rng, config)
print(state)
print(jax.tree_util.tree_flatten(state))
state = jsa.as_scaled_array(state)

logging.info("Start Flax MNIST training...")

for epoch in range(1, config.num_epochs + 1):
rng, input_rng = jax.random.split(rng)
Expand Down

0 comments on commit de39aca

Please sign in to comment.