Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Dec 4, 2023
1 parent de39aca commit 85774af
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions experiments/mnist/flax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ def train_epoch(state, train_ds, batch_size, rng):
for perm in perms:
batch_images = train_ds["image"][perm, ...]
batch_labels = train_ds["label"][perm, ...]

# jaxpr = jax.make_jaxpr(apply_and_update_model)(state, batch_images, batch_labels).jaxpr
# print(jaxpr)

# print(set([v.primitive for v in jaxpr.eqns]))
# assert False

state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
Expand Down Expand Up @@ -145,9 +152,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train
init_rng = jax.random.PRNGKey(1)

state = create_train_state(init_rng, config)
print(state)
print(jax.tree_util.tree_flatten(state))
state = jsa.as_scaled_array(state)
# print(state)
# print(jax.tree_util.tree_flatten(state))
# state = jsa.as_scaled_array(state)

logging.info("Start Flax MNIST training...")

Expand Down

0 comments on commit 85774af

Please sign in to comment.