From ba6115a42528c5571e9f8a543bde37b8e4f11d80 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 24 Nov 2023 17:08:04 +0000 Subject: [PATCH] `autoscale` supporting `custom_jvp` decorator and primitive. (#34) Forwarding ScaledArray inside `custom_jvp` calls. --- jax_scaled_arithmetics/core/interpreters.py | 28 +++++++++++++- jax_scaled_arithmetics/lax/scaled_ops.py | 10 +++++ tests/core/test_interpreter.py | 43 ++++++++++++++++++++- 3 files changed, 79 insertions(+), 2 deletions(-) diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 4010c94..9a26867 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -6,6 +6,7 @@ import jax import numpy as np from jax import core +from jax._src.custom_derivatives import custom_jvp_call_p, custom_vjp_call_p from jax._src.pjit import pjit_p from jax._src.util import safe_map @@ -190,7 +191,8 @@ def promote_to_scaled_array(val): if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE: # Using normal JAX primitive: no scaled inputs, and not always scale rule. - outvals = eqn.primitive.bind(*invals, **eqn.params) + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params) elif scaled_prim_fn is None: raise NotImplementedError( f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet." @@ -232,3 +234,27 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale register_scaled_op(pjit_p, scaled_pjit_translation) + + +def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]: + """Scaled translation of `custom_jvp_call` primitive. Forwarding the scaled call to sub-jaxpr, + and modifying the underlying `jvp` function. + """ + # [fun, jvp], bind_params = custom_jvp_call_p.get_bind_params(params) + call_closed_jaxpr = params["call_jaxpr"] + # FIXME: re-call the custom_jvp decorator/bind. + call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals) + return call_subfunc(*args) + + +register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation) + + +def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]: + """Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr, + and modifying the underlying `vjp` function. + """ + raise NotImplementedError("Scaled custom VJP primitive not yet supported.") + + +register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation) diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index af2e4c5..badd160 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -264,3 +264,13 @@ def scaled_log(val: ScaledArray) -> ScaledArray: @core.register_scaled_lax_op def scaled_select_n(which: jax.Array, *cases: ScaledArray) -> ScaledArray: return scaled_op_default_translation(lax.select_n_p, [which, *cases]) + + +@core.register_scaled_lax_op +def scaled_cos(val: ScaledArray) -> ScaledArray: + return scaled_op_default_translation(lax.cos_p, [val]) + + +@core.register_scaled_lax_op +def scaled_sin(val: ScaledArray) -> ScaledArray: + return scaled_op_default_translation(lax.sin_p, [val]) diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 896f97a..6ed76c4 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -1,12 +1,20 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. import chex import jax +import jax.lax import jax.numpy as jnp import numpy as np import numpy.testing as npt from absl.testing import parameterized -from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, register_scaled_op, scaled_array +from jax_scaled_arithmetics.core import ( + ScaledArray, + asarray, + autoscale, + is_scaled_leaf, + register_scaled_op, + scaled_array, +) from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array @@ -129,6 +137,39 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, assert scaled_out.dtype == exp_out.dtype npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4) + @chex.variants(with_jit=True, without_jit=True) + def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self): + # JAX official `jvp` example. + @jax.custom_jvp + def f(x, y): + return jnp.sin(x) * y + + @f.defjvp + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + primal_out = f(x, y) + tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot + return primal_out, tangent_out + + def fn(x, y): + return jax.jvp(f, (x, y), (x, y)) + + # `autoscale` on `custom_jvp` method. + scaled_inputs = ( + scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), + scaled_array([1.5, -4.5], 2, dtype=np.float32), + ) + scaled_primals, scaled_tangents = self.variant(autoscale(fn))(*scaled_inputs) + # JAX default/expected values + inputs = tuple(map(asarray, scaled_inputs)) + primals, tangents = self.variant(fn)(*inputs) + + assert isinstance(scaled_primals, ScaledArray) + assert isinstance(scaled_tangents, ScaledArray) + npt.assert_array_almost_equal(scaled_primals, primals) + npt.assert_array_almost_equal(scaled_tangents, tangents) + @parameterized.parameters( {"input": np.array(3)}, {"input": jnp.array(3)},