Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify JAX Scalify MNIST examples using jax_scalify.tree methods. #118

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/mnist/mnist_classifier_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def data_stream():
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
params = jsa.tree.astype(params, training_dtype)

@jit
@jsa.scalify
Expand All @@ -119,7 +119,7 @@ def update(params, batch):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
batch = jsa.tree.astype(batch, training_dtype)

with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
Expand Down
5 changes: 2 additions & 3 deletions examples/mnist/mnist_classifier_from_scratch_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import time

import datasets
import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np
Expand Down Expand Up @@ -133,7 +132,7 @@ def data_stream():
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
params = jsa.tree.astype(params, training_dtype)

@jit
@jsa.scalify
Expand All @@ -147,7 +146,7 @@ def update(params, batch):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
batch = jsa.tree.astype(batch, training_dtype)

with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
Expand Down
19 changes: 8 additions & 11 deletions examples/mnist/mnist_classifier_mlp_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def update(model, optimizer, model_state, opt_state, batch):
key = jax.random.PRNGKey(42)
use_scalify: bool = True

# training_dtype = np.dtype(np.float16)
training_dtype = np.dtype(np.float16)
optimizer_dtype = np.dtype(np.float16)
scale_dtype = np.float32

train_images, train_labels, test_images, test_labels = datasets.mnist()
Expand All @@ -102,27 +102,24 @@ def data_stream():
model_state = model.init(key, np.zeros((batch_size, mnist_img_size), dtype=training_dtype))
# Optimizer & optimizer state.
# opt = optax.sgd(learning_rate=step_size)
opt = optax.adam(learning_rate=step_size, eps=1e-5)
opt = optax.adam(learning_rate=step_size, eps=2**-16)
opt_state = opt.init(model_state)
# Freeze model, optimizer (with step size).
update_fn = partial(update, model, opt)

if use_scalify:
# Transform parameters to `ScaledArray` and proper dtype.
# Transform parameters to `ScaledArray`.
model_state = jsa.as_scaled_array(model_state, scale=scale_dtype(1.0))
opt_state = jsa.as_scaled_array(opt_state, scale=scale_dtype(0.0001))

model_state = jax.tree_util.tree_map(
lambda v: v.astype(training_dtype), model_state, is_leaf=jsa.core.is_scaled_leaf
)
# Scalify the update function as well.
update_fn = jsa.scalify(update_fn)
else:
model_state = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), model_state)
# Convert the model state (weights) & optimizer state to proper dtype.
model_state = jsa.tree.astype(model_state, training_dtype)
opt_state = jsa.tree.astype(opt_state, optimizer_dtype, floating_only=True)

print(f"Using Scalify: {use_scalify}")
print(f"Training data format: {training_dtype.name}")
# print(f"Optimizer data format: {training_dtype.name}")
print(f"Optimizer data format: {optimizer_dtype.name}")
print("")

update_fn = jax.jit(update_fn)
Expand All @@ -134,7 +131,7 @@ def data_stream():
for _ in range(num_batches):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch)
batch = jsa.tree.astype(batch, training_dtype)
if use_scalify:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
Expand Down
9 changes: 6 additions & 3 deletions jax_scalify/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> S


def is_scaled_leaf(val: Any) -> bool:
"""Is input a JAX PyTree (scaled) leaf, including ScaledArray.
"""Is input a normal JAX PyTree leaf (i.e. `Array`) or `ScaledArray1.

This function is useful for JAX PyTree handling where the user wants
to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
This function is useful for JAX PyTree handling with `jax.tree` methods where
the user wants to keep the ScaledArray data structures (i.e. not flattened as a
pair of arrays).

See `jax_scalify.tree` for PyTree `jax.tree` methods compatible with `ScaledArray`.
"""
# TODO: check Numpy scalars as well?
return np.isscalar(val) or isinstance(val, (Array, np.ndarray, ScaledArray))
Expand Down
Loading