Skip to content

Commit

Permalink
Support JAX debug callback on ScaledArrays (#56)
Browse files Browse the repository at this point in the history
Intermediate solution at the moment due to JAX debug callback limitations.
It is not possible at present to get the original callback + input PyTree from
the JAX debug_callback primitive, meaning the interpreter can not modify properly
the graph.

In this PR, we re-implement the high level API `debug_callback` of JAX in order
to save the former, and be able to adapt the call to ScaledArrays.
  • Loading branch information
balancap authored Dec 17, 2023
1 parent 318bd41 commit ea8c961
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import lax
from ._version import __version__
from .core import ScaledArray, as_scaled_array, asarray, autoscale, scaled_array # noqa: F401
from .core import ScaledArray, as_scaled_array, asarray, autoscale, debug_callback, scaled_array # noqa: F401
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
is_static_zero,
scaled_array,
)
from .debug import debug_callback # noqa: F401
from .interpreters import ( # noqa: F401
ScaledPrimitiveType,
autoscale,
Expand Down
62 changes: 62 additions & 0 deletions jax_scaled_arithmetics/core/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Callable

from jax import tree_util
from jax._src.debugging import debug_callback as debug_callback_orig
from jax._src.debugging import debug_callback_p

from .interpreters import register_scaled_op


def get_debug_callback_effect(ordered: bool) -> Any:
"""Backward compatible effect factory method."""
try:
from jax._src.debugging import debug_effect, ordered_debug_effect

return ordered_debug_effect if ordered else debug_effect
except ImportError:
from jax._src.debugging import DebugEffect

return DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT


def debug_callback(callback: Callable[..., Any], *args: Any, ordered: bool = False, **kwargs: Any) -> None:
# We need our custom version of `debug_callback` to deal with
# changing JAX pytrees.
# FIXME: probably patch `debug_callback` in JAX.
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
effect = get_debug_callback_effect(ordered)

def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
callback(*args, **kwargs)
return []

# Storing in original PyTree and callback function.
# Allowing custom interpreters to retrieve and modify this information.
_flat_callback.__callback_fn = callback # type:ignore
_flat_callback.__callback_in_tree = in_tree # type:ignore
debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect)


debug_callback.__doc__ = debug_callback_orig.__doc__


def scaled_debug_callback(*args, **params) -> Any:
"""Scaled `debug_callback`: properly forwarding ScaledArrays
to host callback.
"""
flat_callback_fn = params["callback"]
if not hasattr(flat_callback_fn, "__callback_fn"):
raise NotImplementedError("Please use `jsa.debug_callback` function instead of original JAX function.")
callback_fn = flat_callback_fn.__callback_fn
in_pytree = flat_callback_fn.__callback_in_tree
# Re-build original input, with scaled arrays.
scaled_args, scaled_kwargs = tree_util.tree_unflatten(in_pytree, args)
# Re-build ordered boolean, in a backward compatible way.
ordered = "ordered" in str(params["effect"]).lower()
debug_callback(callback_fn, *scaled_args, ordered=ordered, **scaled_kwargs)
return []


register_scaled_op(debug_callback_p, scaled_debug_callback)
32 changes: 31 additions & 1 deletion tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from absl.testing import parameterized
from jax import lax

from jax_scaled_arithmetics.core import Array, ScaledArray, find_registered_scaled_op, scaled_array
from jax_scaled_arithmetics.core import (
Array,
ScaledArray,
autoscale,
debug_callback,
find_registered_scaled_op,
scaled_array,
)
from jax_scaled_arithmetics.lax import (
scaled_broadcast_in_dim,
scaled_concatenate,
Expand All @@ -31,6 +38,29 @@ def setUp(self):
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_debug_callback__proper_forwarding(self):
host_values = []

def callback(*args):
for v in args:
host_values.append(v)

def fn(a):
debug_callback(callback, a, a * 3)
return a

x = scaled_array(self.rs.rand(5), 2, dtype=np.float16)
fn = self.variant(autoscale(fn))
fn(x)

assert len(host_values) == 2
for sv in host_values:
assert isinstance(sv, ScaledArray)
npt.assert_array_equal(sv.data, x.data)
npt.assert_array_equal(host_values[0].scale, x.scale)
npt.assert_array_equal(host_values[1].scale, x.scale * 3)

def test__scaled_broadcast_in_dim__proper_scaling(self):
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,))
Expand Down

0 comments on commit ea8c961

Please sign in to comment.