diff --git a/experiments/mnist/mnist_classifier_from_scratch_fp8.py b/experiments/mnist/mnist_classifier_from_scratch_fp8.py new file mode 100644 index 0000000..d0acfbb --- /dev/null +++ b/experiments/mnist/mnist_classifier_from_scratch_fp8.py @@ -0,0 +1,145 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A basic MNIST example using Numpy and JAX. + +The primary aim here is simplicity and minimal dependencies. +""" + + +import time + +import datasets +import jax +import jax.numpy as jnp +import ml_dtypes +import numpy as np +import numpy.random as npr +from jax import grad, jit, lax + +import jax_scaled_arithmetics as jsa + +# from jax.scipy.special import logsumexp + + +def logsumexp(a, axis=None, keepdims=False): + dims = (axis,) + amax = jnp.max(a, axis=dims, keepdims=keepdims) + # FIXME: not proper scale propagation, introducing NaNs + # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) + amax = lax.stop_gradient(amax) + out = lax.sub(a, amax) + out = lax.exp(out) + out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax) + return out + + +def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): + return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] + + +def predict(params, inputs): + activations = inputs + for w, b in params[:-1]: + # Forward FP8 casting. + w = jsa.ops.cast_ml_dtype(w, ml_dtypes.float8_e4m3fn) + activations = jsa.ops.cast_ml_dtype(activations, ml_dtypes.float8_e4m3fn) + # Matmul + outputs = jnp.dot(activations, w) + # Backward FP8 casting + outputs = jsa.ops.cast_ml_dtype_grad(outputs, ml_dtypes.float8_e5m2) + + # Bias + relu + outputs = outputs + b + activations = jnp.maximum(outputs, 0) + + final_w, final_b = params[-1] + # final_w = jsa.ops.cast_ml_dtype(final_w, ml_dtypes.float8_e4m3fn) + # activations = jsa.ops.cast_ml_dtype(activations, ml_dtypes.float8_e4m3fn) + logits = jnp.dot(activations, final_w) + final_b + + # Dynamic rescaling of the gradient, as logits gradient not properly scaled. + logits = jsa.ops.dynamic_rescale_l2_grad(logits) + logits = logits - logsumexp(logits, axis=1, keepdims=True) + return logits + + +def loss(params, batch): + inputs, targets = batch + preds = predict(params, inputs) + return -jnp.mean(jnp.sum(preds * targets, axis=1)) + + +def accuracy(params, batch): + inputs, targets = batch + target_class = jnp.argmax(targets, axis=1) + predicted_class = jnp.argmax(predict(params, inputs), axis=1) + return jnp.mean(predicted_class == target_class) + + +if __name__ == "__main__": + layer_sizes = [784, 1024, 1024, 10] + param_scale = 1.0 + step_size = 0.001 + num_epochs = 10 + batch_size = 128 + + training_dtype = np.float16 + scale_dtype = np.float32 + + train_images, train_labels, test_images, test_labels = datasets.mnist() + num_train = train_images.shape[0] + num_complete_batches, leftover = divmod(num_train, batch_size) + num_batches = num_complete_batches + bool(leftover) + + def data_stream(): + rng = npr.RandomState(0) + while True: + perm = rng.permutation(num_train) + for i in range(num_batches): + batch_idx = perm[i * batch_size : (i + 1) * batch_size] + yield train_images[batch_idx], train_labels[batch_idx] + + batches = data_stream() + params = init_random_params(param_scale, layer_sizes) + # Transform parameters to `ScaledArray` and proper dtype. + params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) + params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) + + @jit + @jsa.autoscale + def update(params, batch): + grads = grad(loss)(params, batch) + return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] + + for epoch in range(num_epochs): + start_time = time.time() + for _ in range(num_batches): + batch = next(batches) + # Scaled micro-batch + training dtype cast. + batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) + batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) + + with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): + params = update(params, batch) + + epoch_time = time.time() - start_time + + # Evaluation in float32, for consistency. + raw_params = jsa.asarray(params, dtype=np.float32) + train_acc = accuracy(raw_params, (train_images, train_labels)) + test_acc = accuracy(raw_params, (test_images, test_labels)) + print(f"Epoch {epoch} in {epoch_time:0.2f} sec") + print(f"Training set accuracy {train_acc:0.5f}") + print(f"Test set accuracy {test_acc:0.5f}") diff --git a/jax_scaled_arithmetics/ops/__init__.py b/jax_scaled_arithmetics/ops/__init__.py index deb516d..ea132b8 100644 --- a/jax_scaled_arithmetics/ops/__init__.py +++ b/jax_scaled_arithmetics/ops/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad # noqa: F401 +from .ml_dtypes import cast_ml_dtype, cast_ml_dtype_grad # noqa: F401 from .rescaling import ( # noqa: F401 dynamic_rescale_l1, dynamic_rescale_l1_grad, diff --git a/jax_scaled_arithmetics/ops/ml_dtypes.py b/jax_scaled_arithmetics/ops/ml_dtypes.py new file mode 100644 index 0000000..940dcdd --- /dev/null +++ b/jax_scaled_arithmetics/ops/ml_dtypes.py @@ -0,0 +1,25 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import partial + +import jax +import ml_dtypes + +from jax_scaled_arithmetics.core import Array, DTypeLike + +from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd + + +def cast_ml_dtype_base(arr: Array, dtype: DTypeLike) -> Array: + """`Fake` cast to an ML dtype (e.g. FP8), using JAX LAX `reduce_precision` operator.""" + info = ml_dtypes.finfo(dtype) + return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant) + + +def cast_ml_dtype(arr: Array, dtype: DTypeLike) -> Array: + """`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass).""" + return partial(fn_fwd_identity_bwd, lambda v: cast_ml_dtype_base(v, dtype))(arr) + + +def cast_ml_dtype_grad(arr: Array, dtype: DTypeLike) -> Array: + """`Fake` cast to an ML dtype on the backward pass (no-op on forward pass).""" + return partial(fn_bwd_identity_fwd, lambda v: cast_ml_dtype_base(v, dtype))(arr) diff --git a/jax_scaled_arithmetics/ops/rescaling.py b/jax_scaled_arithmetics/ops/rescaling.py index d3a7b4e..f2e0325 100644 --- a/jax_scaled_arithmetics/ops/rescaling.py +++ b/jax_scaled_arithmetics/ops/rescaling.py @@ -9,37 +9,37 @@ @partial(jax.custom_vjp, nondiff_argnums=(0,)) -def fn_with_identity_grad(f, arg): +def fn_fwd_identity_bwd(f, arg): """Function with identity bwd/grad.""" return f(arg) -def fn_with_identity_grad_fwd(f, arg): +def fn_fwd_identity_bwd_fwd(f, arg): return arg, None -def fn_with_identity_grad_bwd(f, _, grad): +def fn_fwd_identity_bwd_bwd(f, _, grad): return (grad,) -fn_with_identity_grad.defvjp(fn_with_identity_grad_fwd, fn_with_identity_grad_bwd) +fn_fwd_identity_bwd.defvjp(fn_fwd_identity_bwd_fwd, fn_fwd_identity_bwd_bwd) @partial(jax.custom_vjp, nondiff_argnums=(0,)) -def fn_on_grad(f, arg): +def fn_bwd_identity_fwd(f, arg): """Apply a function on the gradient/backward pass.""" return arg -def fn_on_grad_fwd(f, arg): +def fn_bwd_identity_fwd_fwd(f, arg): return arg, None -def fn_on_grad_bwd(f, _, grad): +def fn_bwd_identity_fwd_bwd(f, _, grad): return (f(grad),) -fn_on_grad.defvjp(fn_on_grad_fwd, fn_on_grad_bwd) +fn_bwd_identity_fwd.defvjp(fn_bwd_identity_fwd_fwd, fn_bwd_identity_fwd_bwd) def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray: @@ -97,11 +97,11 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray: # Dynamic rescale on fwd arrays. -dynamic_rescale_max = partial(fn_with_identity_grad, dynamic_rescale_max_base) -dynamic_rescale_l1 = partial(fn_with_identity_grad, dynamic_rescale_l1_base) -dynamic_rescale_l2 = partial(fn_with_identity_grad, dynamic_rescale_l2_base) +dynamic_rescale_max = partial(fn_fwd_identity_bwd, dynamic_rescale_max_base) +dynamic_rescale_l1 = partial(fn_fwd_identity_bwd, dynamic_rescale_l1_base) +dynamic_rescale_l2 = partial(fn_fwd_identity_bwd, dynamic_rescale_l2_base) # Dynamic rescale on gradients. -dynamic_rescale_max_grad = partial(fn_on_grad, dynamic_rescale_max_base) -dynamic_rescale_l1_grad = partial(fn_on_grad, dynamic_rescale_l1_base) -dynamic_rescale_l2_grad = partial(fn_on_grad, dynamic_rescale_l2_base) +dynamic_rescale_max_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_max_base) +dynamic_rescale_l1_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l1_base) +dynamic_rescale_l2_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l2_base) diff --git a/pyproject.toml b/pyproject.toml index 4777f64..2e859eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "chex >= 0.1.6", "jax >= 0.3.16", "jaxlib >= 0.3.15", + "ml_dtypes", "numpy >= 1.22.4" ] diff --git a/tests/ops/test_ml_dtypes.py b/tests/ops/test_ml_dtypes.py new file mode 100644 index 0000000..1e3dc3b --- /dev/null +++ b/tests/ops/test_ml_dtypes.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import partial + +import chex +import ml_dtypes +import numpy as np +import numpy.testing as npt +from absl.testing import parameterized + +from jax_scaled_arithmetics.core import autoscale, scaled_array +from jax_scaled_arithmetics.ops import cast_ml_dtype + + +class CastMLDtypeTests(chex.TestCase): + @parameterized.parameters( + {"ml_dtype": ml_dtypes.float8_e4m3fn}, + {"ml_dtype": ml_dtypes.float8_e5m2}, + ) + def test__cast_ml_dtype__consistent_rounding_down(self, ml_dtype): + # Values potentially "problematic" in FP8. + values = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) + out = cast_ml_dtype(values, dtype=ml_dtype) + expected_out = values.astype(ml_dtype) + assert out.dtype == values.dtype + npt.assert_array_equal(out, expected_out) + + @parameterized.parameters( + {"ml_dtype": ml_dtypes.float8_e4m3fn}, + {"ml_dtype": ml_dtypes.float8_e5m2}, + ) + def test__cast_ml_dtype__autoscale_compatiblity(self, ml_dtype): + values = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) + arr = scaled_array(values, np.float32(1)) + out = autoscale(partial(cast_ml_dtype, dtype=ml_dtype))(arr) + + npt.assert_array_equal(out.scale, arr.scale) + npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype))