Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

autoscale supporting custom_jvp decorator and primitive. #34

Merged
merged 1 commit into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax
import numpy as np
from jax import core
from jax._src.custom_derivatives import custom_jvp_call_p, custom_vjp_call_p
from jax._src.pjit import pjit_p
from jax._src.util import safe_map

Expand Down Expand Up @@ -190,7 +191,8 @@ def promote_to_scaled_array(val):

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
outvals = eqn.primitive.bind(*invals, **eqn.params)
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
Expand Down Expand Up @@ -232,3 +234,27 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale


register_scaled_op(pjit_p, scaled_pjit_translation)


def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
    """Scaled translation rule for the `custom_jvp_call` primitive.

    The primal computation, stored as a closed jaxpr under the `call_jaxpr`
    entry of the primitive parameters, is forwarded to `autoscale_jaxpr`
    together with its constants and the (scaled) input arguments.

    Args:
        *args: Inputs of the `custom_jvp_call` equation.
        **params: Parameters of the equation; only `call_jaxpr` is read here.

    Returns:
        Whatever `autoscale_jaxpr` returns for the sub-jaxpr.
    """
    closed_jaxpr = params["call_jaxpr"]
    # FIXME: the custom JVP rule is not re-attached to the scaled call: the
    # `custom_jvp` decorator/bind should be re-applied on the scaled function.
    return autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)


register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation)


def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
    """Placeholder scaled translation rule for the `custom_vjp_call` primitive.

    No scaled translation exists yet for custom VJP calls: this rule is only
    registered so that using the primitive fails with an explicit error.

    Args:
        *args: Inputs of the `custom_vjp_call` equation (unused).
        **params: Parameters of the equation (unused).

    Raises:
        NotImplementedError: Always.
    """
    error_message = "Scaled custom VJP primitive not yet supported."
    raise NotImplementedError(error_message)


register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation)
10 changes: 10 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,13 @@ def scaled_log(val: ScaledArray) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_select_n(which: jax.Array, *cases: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.select_n_p, [which, *cases])


@core.register_scaled_lax_op
def scaled_cos(val: ScaledArray) -> ScaledArray:
    """Scaled translation of `lax.cos_p`, delegated to the default translation helper.

    Args:
        val: Single operand of the cosine primitive.

    Returns:
        The result of `scaled_op_default_translation` for `lax.cos_p`.
    """
    operands = [val]
    return scaled_op_default_translation(lax.cos_p, operands)


@core.register_scaled_lax_op
def scaled_sin(val: ScaledArray) -> ScaledArray:
    """Scaled translation of `lax.sin_p`, delegated to the default translation helper.

    Args:
        val: Single operand of the sine primitive.

    Returns:
        The result of `scaled_op_default_translation` for `lax.sin_p`.
    """
    operands = [val]
    return scaled_op_default_translation(lax.sin_p, operands)
43 changes: 42 additions & 1 deletion tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import ScaledArray, autoscale, is_scaled_leaf, register_scaled_op, scaled_array
from jax_scaled_arithmetics.core import (
ScaledArray,
asarray,
autoscale,
is_scaled_leaf,
register_scaled_op,
scaled_array,
)
from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array


Expand Down Expand Up @@ -129,6 +137,39 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
assert scaled_out.dtype == exp_out.dtype
npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)

@chex.variants(with_jit=True, without_jit=True)
def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self):
    """Check `autoscale` on a `jax.jvp` call of a `custom_jvp` function.

    The scaled primal and tangent outputs are compared against the values
    obtained by running the same function on the equivalent plain arrays.
    """

    # Function with a custom JVP rule, as in the official JAX `jvp` example.
    @jax.custom_jvp
    def f(x, y):
        return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
        x, y = primals
        x_dot, y_dot = tangents
        primal_out = f(x, y)
        tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
        return primal_out, tangent_out

    def jvp_of_f(x, y):
        # The inputs are used both as primals and as tangents.
        return jax.jvp(f, (x, y), (x, y))

    # Scaled evaluation: `autoscale` applied on the `custom_jvp` method.
    scaled_x = scaled_array([-2.0, 0.5], 0.5, dtype=np.float32)
    scaled_y = scaled_array([1.5, -4.5], 2, dtype=np.float32)
    scaled_inputs = (scaled_x, scaled_y)
    scaled_primals, scaled_tangents = self.variant(autoscale(jvp_of_f))(*scaled_inputs)

    # Reference evaluation: default JAX on the corresponding plain arrays.
    inputs = tuple(asarray(scaled_val) for scaled_val in scaled_inputs)
    primals, tangents = self.variant(jvp_of_f)(*inputs)

    assert isinstance(scaled_primals, ScaledArray)
    assert isinstance(scaled_tangents, ScaledArray)
    npt.assert_array_almost_equal(scaled_primals, primals)
    npt.assert_array_almost_equal(scaled_tangents, tangents)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
Expand Down