diff --git a/experiments/mnist/flax/train.py b/experiments/mnist/flax/train.py
index 44b8b78..0c520e8 100644
--- a/experiments/mnist/flax/train.py
+++ b/experiments/mnist/flax/train.py
@@ -98,6 +98,13 @@ def train_epoch(state, train_ds, batch_size, rng):
     for perm in perms:
         batch_images = train_ds["image"][perm, ...]
         batch_labels = train_ds["label"][perm, ...]
+
+        # jaxpr = jax.make_jaxpr(apply_and_update_model)(state, batch_images, batch_labels).jaxpr
+        # print(jaxpr)
+
+        # print(set([v.primitive for v in jaxpr.eqns]))
+        # assert False
+
         state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels)
         epoch_loss.append(loss)
         epoch_accuracy.append(accuracy)
@@ -145,9 +152,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train
     init_rng = jax.random.PRNGKey(1)

     state = create_train_state(init_rng, config)
-    print(state)
-    print(jax.tree_util.tree_flatten(state))
-    state = jsa.as_scaled_array(state)
+    # print(state)
+    # print(jax.tree_util.tree_flatten(state))
+    # state = jsa.as_scaled_array(state)

     logging.info("Start Flax MNIST training...")