Skip to content

Commit

Permalink
Support inputs complex PyTree in autoscale.
Browse files Browse the repository at this point in the history
Properly using `jax.tree_map` instead of basic map, and inputs flattening.
  • Loading branch information
balancap committed Nov 24, 2023
1 parent bdda55a commit 34ea4de
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
10 changes: 7 additions & 3 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jax._src.pjit import pjit_p
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray
from .datatype import NDArray, ScaledArray, is_scaled_leaf


class ScaledPrimitiveType(IntEnum):
Expand Down Expand Up @@ -138,14 +138,16 @@ def wrapped(*args, **kwargs):
if len(kwargs) > 0:
raise NotImplementedError("`autoscale` JAX interpreter not supporting named tensors at present.")

aval_args = safe_map(lambda x: _get_aval(x), args)
aval_args = jax.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
# Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)

# Flattening of PyTree inputs.
inputs_scaled = args
inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf)
# Trace the graph & convert to scaled one.
outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled)
outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat)
# Reconstruct the output Pytree, with scaled arrays.
# NOTE: this step is also handling single vs multi outputs.
assert len(out_leaves) == len(outputs_scaled_flat)
Expand Down Expand Up @@ -174,6 +176,8 @@ def promote_to_scaled_array(val):
# No promotion rule => just return as such.
return val

# A few initial checks to make sure there is consistency.
assert len(jaxpr.invars) == len(args)
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

Expand Down
14 changes: 12 additions & 2 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,20 @@ def myfunc(x):
assert jaxpr.outvars[0].aval.shape == scaled_input.shape
assert jaxpr.outvars[1].aval.shape == ()

@chex.variants(with_jit=False, without_jit=True)
@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
# Identity function!
{"fn": lambda x: x, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
# Non-trivial input JAX pytree.
{
"fn": lambda vals: vals["x"] * vals["y"],
"inputs": [
{
"x": scaled_array([1.0, 2.0], 3, dtype=np.float32),
"y": scaled_array([1.5, -2.5], 2, dtype=np.float32),
}
],
},
# Non-trivial output JAX pytree
{"fn": lambda x: {"x": (x,)}, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
# Multi-inputs operation.
Expand Down Expand Up @@ -103,7 +113,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
scaled_fn = self.variant(autoscale(fn))
scaled_output = scaled_fn(*inputs)
# Normal JAX path, without scaled arrays.
raw_inputs = [np.asarray(v) for v in inputs]
raw_inputs = jax.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
expected_output = self.variant(fn)(*raw_inputs)

# Do we re-construct properly the output type (i.e. handling Pytree properly)?
Expand Down

0 comments on commit 34ea4de

Please sign in to comment.