diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index b0948c1..6a83a1b 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -6,7 +6,12 @@ import jax import numpy as np from jax import core -from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p +from jax._src.custom_derivatives import ( + custom_jvp_call_jaxpr_p, + custom_jvp_call_p, + custom_vjp_call_jaxpr_p, + custom_vjp_call_p, +) from jax._src.util import safe_map from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf @@ -287,7 +292,12 @@ def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Seq """Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr, and modifying the underlying `vjp` function. """ - raise NotImplementedError("Scaled custom VJP primitive not yet supported.") + key_jaxpr = "fun_jaxpr" + call_closed_jaxpr = params[key_jaxpr] + # FIXME: re-call the custom_vjp decorator/bind. Until then only `fun_jaxpr` is autoscaled; other `params` entries are ignored. + call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals) + return call_subfunc(*args) register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation) +register_scaled_op(custom_vjp_call_jaxpr_p, scaled_custom_vjp_call_translation) diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index b24092b..06d87a1 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -171,6 +171,42 @@ def fn(x, y): npt.assert_array_almost_equal(scaled_primals, primals) npt.assert_array_almost_equal(scaled_tangents, tangents) + @chex.variants(with_jit=False, without_jit=True) + def test__autoscale_decorator__custom_vjp__proper_graph_transformation_and_result(self): + # JAX official `custom_vjp` example. 
+ @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + + def fn(x, y): + primals, f_vjp = jax.vjp(f, x, y) + return primals, f_vjp(x * y) + + # `autoscale` on `custom_vjp` method. + scaled_inputs = ( + scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), + scaled_array([1.5, -4.5], 2, dtype=np.float32), + ) + scaled_primals, scaled_grads = self.variant(autoscale(fn))(*scaled_inputs) + # JAX default/expected values + inputs = tuple(map(asarray, scaled_inputs)) + primals, grads = self.variant(fn)(*inputs) + + assert isinstance(scaled_primals, ScaledArray) + npt.assert_array_almost_equal(scaled_primals, primals) + for g, sg in zip(grads, scaled_grads): + assert isinstance(sg, ScaledArray) + npt.assert_array_almost_equal(sg, g) + @parameterized.parameters( {"input": 3.0}, {"input": np.float32(3.0)},