diff --git a/examples/mnist/mnist_classifier_from_scratch.py b/examples/mnist/mnist_classifier_from_scratch.py
index fb54341..a5ae542 100644
--- a/examples/mnist/mnist_classifier_from_scratch.py
+++ b/examples/mnist/mnist_classifier_from_scratch.py
@@ -105,7 +105,7 @@ def data_stream():
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
     params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
-    params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+    params = jsa.tree.astype(params, training_dtype)
 
     @jit
     @jsa.scalify
@@ -119,7 +119,7 @@ def update(params, batch):
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
-            batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+            batch = jsa.tree.astype(batch, training_dtype)
 
             with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
                 params = update(params, batch)
diff --git a/examples/mnist/mnist_classifier_from_scratch_fp8.py b/examples/mnist/mnist_classifier_from_scratch_fp8.py
index 5f142d1..0de9219 100644
--- a/examples/mnist/mnist_classifier_from_scratch_fp8.py
+++ b/examples/mnist/mnist_classifier_from_scratch_fp8.py
@@ -22,7 +22,6 @@
 import time
 
 import datasets
-import jax
 import jax.numpy as jnp
 import ml_dtypes
 import numpy as np
@@ -133,7 +132,7 @@ def data_stream():
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
     params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
-    params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+    params = jsa.tree.astype(params, training_dtype)
 
     @jit
     @jsa.scalify
@@ -147,7 +146,7 @@ def update(params, batch):
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
-            batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+            batch = jsa.tree.astype(batch, training_dtype)
 
             with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
                 params = update(params, batch)
diff --git a/examples/mnist/mnist_classifier_mlp_flax.py b/examples/mnist/mnist_classifier_mlp_flax.py
index 03c44d5..1a88491 100644
--- a/examples/mnist/mnist_classifier_mlp_flax.py
+++ b/examples/mnist/mnist_classifier_mlp_flax.py
@@ -79,8 +79,8 @@ def update(model, optimizer, model_state, opt_state, batch):
     key = jax.random.PRNGKey(42)
 
     use_scalify: bool = True
-    # training_dtype = np.dtype(np.float16)
     training_dtype = np.dtype(np.float16)
+    optimizer_dtype = np.dtype(np.float16)
     scale_dtype = np.float32
 
     train_images, train_labels, test_images, test_labels = datasets.mnist()
@@ -102,27 +102,24 @@ def data_stream():
     model_state = model.init(key, np.zeros((batch_size, mnist_img_size), dtype=training_dtype))
     # Optimizer & optimizer state.
     # opt = optax.sgd(learning_rate=step_size)
-    opt = optax.adam(learning_rate=step_size, eps=1e-5)
+    opt = optax.adam(learning_rate=step_size, eps=2**-16)
     opt_state = opt.init(model_state)
     # Freeze model, optimizer (with step size).
     update_fn = partial(update, model, opt)
 
     if use_scalify:
-        # Transform parameters to `ScaledArray` and proper dtype.
+        # Transform parameters to `ScaledArray`.
         model_state = jsa.as_scaled_array(model_state, scale=scale_dtype(1.0))
         opt_state = jsa.as_scaled_array(opt_state, scale=scale_dtype(0.0001))
-
-        model_state = jax.tree_util.tree_map(
-            lambda v: v.astype(training_dtype), model_state, is_leaf=jsa.core.is_scaled_leaf
-        )
         # Scalify the update function as well.
         update_fn = jsa.scalify(update_fn)
-    else:
-        model_state = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), model_state)
 
+    # Convert the model state (weights) & optimizer state to proper dtype.
+    model_state = jsa.tree.astype(model_state, training_dtype)
+    opt_state = jsa.tree.astype(opt_state, optimizer_dtype, floating_only=True)
     print(f"Using Scalify: {use_scalify}")
     print(f"Training data format: {training_dtype.name}")
-    # print(f"Optimizer data format: {training_dtype.name}")
+    print(f"Optimizer data format: {optimizer_dtype.name}")
     print("")
 
     update_fn = jax.jit(update_fn)
@@ -134,7 +131,7 @@ def data_stream():
         for _ in range(num_batches):
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
-            batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch)
+            batch = jsa.tree.astype(batch, training_dtype)
             if use_scalify:
                 batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
             with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
diff --git a/jax_scalify/core/datatype.py b/jax_scalify/core/datatype.py
index dba692f..f61f5c6 100644
--- a/jax_scalify/core/datatype.py
+++ b/jax_scalify/core/datatype.py
@@ -132,10 +132,13 @@ def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> S
 
 
 def is_scaled_leaf(val: Any) -> bool:
-    """Is input a JAX PyTree (scaled) leaf, including ScaledArray.
+    """Is input a normal JAX PyTree leaf (i.e. `Array`) or `ScaledArray`.
 
-    This function is useful for JAX PyTree handling where the user wants
-    to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
+    This function is useful for JAX PyTree handling with `jax.tree` methods where
+    the user wants to keep the ScaledArray data structures (i.e. not flattened as a
+    pair of arrays).
+
+    See `jax_scalify.tree` for PyTree `jax.tree` methods compatible with `ScaledArray`.
     """
     # TODO: check Numpy scalars as well?
     return np.isscalar(val) or isinstance(val, (Array, np.ndarray, ScaledArray))
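
Usage note (not part of the patch): a minimal sketch of how the new `jsa.tree.astype` helper replaces the `jax.tree_util.tree_map(..., is_leaf=jsa.core.is_scaled_leaf)` casts above. It assumes `jax_scalify` is importable as `jsa`; the toy pytrees below are made up for illustration.

    import jax
    import numpy as np
    import jax_scalify as jsa

    # Toy parameter pytree (illustrative only).
    params = {"w": np.ones((4, 4), np.float32), "b": np.zeros((4,), np.float32)}
    # Wrap leaves as `ScaledArray`, then cast the underlying data to the training dtype.
    params = jsa.as_scaled_array(params, scale=np.float32(1.0))
    params = jsa.tree.astype(params, np.float16)

    # `floating_only=True` casts only floating-point leaves, leaving e.g. integer
    # step counters in an optimizer state untouched.
    opt_like = {"count": np.int32(0), "mu": np.zeros((4, 4), np.float32)}
    opt_like = jsa.tree.astype(opt_like, np.float16, floating_only=True)

    # `jsa.core.is_scaled_leaf` keeps `ScaledArray` leaves whole (not flattened into
    # data/scale pairs) when calling `jax.tree_util` directly.
    leaves = jax.tree_util.tree_leaves(params, is_leaf=jsa.core.is_scaled_leaf)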