From ea8c96109a51bec4f4b68118f0ded01a62fa9654 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 17 Dec 2023 22:20:13 +0000 Subject: [PATCH] Support JAX debug callback on ScaledArrays (#56) This is an intermediate solution for now, due to limitations of the JAX debug callback. It is currently not possible to retrieve the original callback and input PyTree from the JAX `debug_callback` primitive, meaning the interpreter cannot properly modify the graph. In this PR, we re-implement JAX's high-level `debug_callback` API in order to save both of them and to be able to adapt the call to ScaledArrays. --- jax_scaled_arithmetics/__init__.py | 2 +- jax_scaled_arithmetics/core/__init__.py | 1 + jax_scaled_arithmetics/core/debug.py | 62 +++++++++++++++++++++++++ tests/lax/test_scaled_ops.py | 32 ++++++++++++- 4 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 jax_scaled_arithmetics/core/debug.py diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scaled_arithmetics/__init__.py index d9ba269..156a298 100644 --- a/jax_scaled_arithmetics/__init__.py +++ b/jax_scaled_arithmetics/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from . 
import lax from ._version import __version__ -from .core import ScaledArray, as_scaled_array, asarray, autoscale, scaled_array # noqa: F401 +from .core import ScaledArray, as_scaled_array, asarray, autoscale, debug_callback, scaled_array # noqa: F401 diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index ee950a2..8cd6bc5 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -10,6 +10,7 @@ is_static_zero, scaled_array, ) +from .debug import debug_callback # noqa: F401 from .interpreters import ( # noqa: F401 ScaledPrimitiveType, autoscale, diff --git a/jax_scaled_arithmetics/core/debug.py b/jax_scaled_arithmetics/core/debug.py new file mode 100644 index 0000000..c06f7a0 --- /dev/null +++ b/jax_scaled_arithmetics/core/debug.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Any, Callable + +from jax import tree_util +from jax._src.debugging import debug_callback as debug_callback_orig +from jax._src.debugging import debug_callback_p + +from .interpreters import register_scaled_op + + +def get_debug_callback_effect(ordered: bool) -> Any: + """Backward compatible effect factory method.""" + try: + from jax._src.debugging import debug_effect, ordered_debug_effect + + return ordered_debug_effect if ordered else debug_effect + except ImportError: + from jax._src.debugging import DebugEffect + + return DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT + + +def debug_callback(callback: Callable[..., Any], *args: Any, ordered: bool = False, **kwargs: Any) -> None: + # We need our custom version of `debug_callback` to deal with + # changing JAX pytrees. + # FIXME: probably patch `debug_callback` in JAX. 
+ flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + effect = get_debug_callback_effect(ordered) + + def _flat_callback(*flat_args): + args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) + callback(*args, **kwargs) + return [] + + # Storing in original PyTree and callback function. + # Allowing custom interpreters to retrieve and modify this information. + _flat_callback.__callback_fn = callback # type:ignore + _flat_callback.__callback_in_tree = in_tree # type:ignore + debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect) + + +debug_callback.__doc__ = debug_callback_orig.__doc__ + + +def scaled_debug_callback(*args, **params) -> Any: + """Scaled `debug_callback`: properly forwarding ScaledArrays + to host callback. + """ + flat_callback_fn = params["callback"] + if not hasattr(flat_callback_fn, "__callback_fn"): + raise NotImplementedError("Please use `jsa.debug_callback` function instead of original JAX function.") + callback_fn = flat_callback_fn.__callback_fn + in_pytree = flat_callback_fn.__callback_in_tree + # Re-build original input, with scaled arrays. + scaled_args, scaled_kwargs = tree_util.tree_unflatten(in_pytree, args) + # Re-build ordered boolean, in a backward compatible way. 
+ ordered = "ordered" in str(params["effect"]).lower() + debug_callback(callback_fn, *scaled_args, ordered=ordered, **scaled_kwargs) + return [] + + +register_scaled_op(debug_callback_p, scaled_debug_callback) diff --git a/tests/lax/test_scaled_ops.py b/tests/lax/test_scaled_ops.py index afa106a..587bc1d 100644 --- a/tests/lax/test_scaled_ops.py +++ b/tests/lax/test_scaled_ops.py @@ -5,7 +5,14 @@ from absl.testing import parameterized from jax import lax -from jax_scaled_arithmetics.core import Array, ScaledArray, find_registered_scaled_op, scaled_array +from jax_scaled_arithmetics.core import ( + Array, + ScaledArray, + autoscale, + debug_callback, + find_registered_scaled_op, + scaled_array, +) from jax_scaled_arithmetics.lax import ( scaled_broadcast_in_dim, scaled_concatenate, @@ -31,6 +38,29 @@ def setUp(self): # Use random state for reproducibility! self.rs = np.random.RandomState(42) + @chex.variants(with_jit=True, without_jit=True) + def test__scaled_debug_callback__proper_forwarding(self): + host_values = [] + + def callback(*args): + for v in args: + host_values.append(v) + + def fn(a): + debug_callback(callback, a, a * 3) + return a + + x = scaled_array(self.rs.rand(5), 2, dtype=np.float16) + fn = self.variant(autoscale(fn)) + fn(x) + + assert len(host_values) == 2 + for sv in host_values: + assert isinstance(sv, ScaledArray) + npt.assert_array_equal(sv.data, x.data) + npt.assert_array_equal(host_values[0].scale, x.scale) + npt.assert_array_equal(host_values[1].scale, x.scale * 3) + def test__scaled_broadcast_in_dim__proper_scaling(self): x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,))