Skip to content

Commit

Permalink
autoscale supporting custom_jvp decorator and primitive. (#34)
Browse files Browse the repository at this point in the history
Forwarding ScaledArray inside `custom_jvp` calls.
  • Loading branch information
balancap authored Nov 24, 2023
1 parent e590d6f commit ba6115a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 2 deletions.
28 changes: 27 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax
import numpy as np
from jax import core
from jax._src.custom_derivatives import custom_jvp_call_p, custom_vjp_call_p
from jax._src.pjit import pjit_p
from jax._src.util import safe_map

Expand Down Expand Up @@ -190,7 +191,8 @@ def promote_to_scaled_array(val):

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
outvals = eqn.primitive.bind(*invals, **eqn.params)
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
Expand Down Expand Up @@ -232,3 +234,27 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale


register_scaled_op(pjit_p, scaled_pjit_translation)


def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
"""Scaled translation of `custom_jvp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `jvp` function.
"""
# [fun, jvp], bind_params = custom_jvp_call_p.get_bind_params(params)
call_closed_jaxpr = params["call_jaxpr"]
# FIXME: re-call the custom_jvp decorator/bind.
call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals)
return call_subfunc(*args)


register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation)


def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
"""Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `vjp` function.
"""
raise NotImplementedError("Scaled custom VJP primitive not yet supported.")


register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation)
10 changes: 10 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,13 @@ def scaled_log(val: ScaledArray) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_select_n(which: jax.Array, *cases: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.select_n_p, [which, *cases])


@core.register_scaled_lax_op
def scaled_cos(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.cos_p, [val])


@core.register_scaled_lax_op
def scaled_sin(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.sin_p, [val])
43 changes: 42 additions & 1 deletion tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, register_scaled_op, scaled_array
from jax_scaled_arithmetics.core import (
ScaledArray,
asarray,
autoscale,
is_scaled_leaf,
register_scaled_op,
scaled_array,
)
from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array


Expand Down Expand Up @@ -129,6 +137,39 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
assert scaled_out.dtype == exp_out.dtype
npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)

@chex.variants(with_jit=True, without_jit=True)
def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self):
# JAX official `jvp` example.
@jax.custom_jvp
def f(x, y):
return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out

def fn(x, y):
return jax.jvp(f, (x, y), (x, y))

# `autoscale` on `custom_jvp` method.
scaled_inputs = (
scaled_array([-2.0, 0.5], 0.5, dtype=np.float32),
scaled_array([1.5, -4.5], 2, dtype=np.float32),
)
scaled_primals, scaled_tangents = self.variant(autoscale(fn))(*scaled_inputs)
# JAX default/expected values
inputs = tuple(map(asarray, scaled_inputs))
primals, tangents = self.variant(fn)(*inputs)

assert isinstance(scaled_primals, ScaledArray)
assert isinstance(scaled_tangents, ScaledArray)
npt.assert_array_almost_equal(scaled_primals, primals)
npt.assert_array_almost_equal(scaled_tangents, tangents)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
Expand Down

0 comments on commit ba6115a

Please sign in to comment.