Skip to content

Commit

Permalink
Support custom_vjp JAX primitive in AutoScale. (#51)
Browse files Browse the repository at this point in the history
Basic support. TODO: properly wrap in `custom_vjp` call.
  • Loading branch information
balancap authored Dec 12, 2023
1 parent 9d02472 commit a3b3304
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
14 changes: 12 additions & 2 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import jax
import numpy as np
from jax import core
from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p
from jax._src.custom_derivatives import (
custom_jvp_call_jaxpr_p,
custom_jvp_call_p,
custom_vjp_call_jaxpr_p,
custom_vjp_call_p,
)
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
Expand Down Expand Up @@ -287,7 +292,12 @@ def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Seq
"""Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `vjp` function.
"""
raise NotImplementedError("Scaled custom VJP primitive not yet supported.")
key_jaxpr = "fun_jaxpr"
call_closed_jaxpr = params[key_jaxpr]
# FIXME: re-call the custom_vjp decorator/bind.
call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals)
return call_subfunc(*args)


register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation)
register_scaled_op(custom_vjp_call_jaxpr_p, scaled_custom_vjp_call_translation)
36 changes: 36 additions & 0 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,42 @@ def fn(x, y):
npt.assert_array_almost_equal(scaled_primals, primals)
npt.assert_array_almost_equal(scaled_tangents, tangents)

@chex.variants(with_jit=False, without_jit=True)
def test__autoscale_decorator__custom_vjp__proper_graph_transformation_and_result(self):
# JAX official `vjp` example.
@jax.custom_vjp
def f(x, y):
return jnp.sin(x) * y

def f_fwd(x, y):
return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

def fn(x, y):
primals, f_vjp = jax.vjp(f, x, y)
return primals, f_vjp(x * y)

# `autoscale` on `custom_jvp` method.
scaled_inputs = (
scaled_array([-2.0, 0.5], 0.5, dtype=np.float32),
scaled_array([1.5, -4.5], 2, dtype=np.float32),
)
scaled_primals, scaled_grads = self.variant(autoscale(fn))(*scaled_inputs)
# JAX default/expected values
inputs = tuple(map(asarray, scaled_inputs))
primals, grads = self.variant(fn)(*inputs)

assert isinstance(scaled_primals, ScaledArray)
npt.assert_array_almost_equal(scaled_primals, primals)
for g, sg in zip(grads, scaled_grads):
assert isinstance(sg, ScaledArray)
npt.assert_array_almost_equal(sg, g)

@parameterized.parameters(
{"input": 3.0},
{"input": np.float32(3.0)},
Expand Down

0 comments on commit a3b3304

Please sign in to comment.