Support complex PyTree inputs in autoscale. #33

Merged
merged 1 commit on Nov 24, 2023
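In short: this PR lets the `autoscale` interpreter accept arbitrary pytree inputs (dicts, tuples, nested containers) of scaled arrays, rather than only flat lists of array arguments. A minimal usage sketch of what the change enables, mirroring the new test case in the diff below; the import path is an assumption, only `autoscale` and `scaled_array` appear in the diff:

import numpy as np
# Assumed import path; `autoscale` and `scaled_array` are the names used in the diff below.
from jax_scaled_arithmetics.core import autoscale, scaled_array

def fn(vals):
    # Dict-valued input: previously autoscale only handled flat array arguments.
    return vals["x"] * vals["y"]

scaled_fn = autoscale(fn)
out = scaled_fn(
    {
        "x": scaled_array([1.0, 2.0], 3, dtype=np.float32),
        "y": scaled_array([1.5, -2.5], 2, dtype=np.float32),
    }
)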
jax_scaled_arithmetics/core/interpreters.py (10 changes: 7 additions & 3 deletions)
@@ -9,7 +9,7 @@
 from jax._src.pjit import pjit_p
 from jax._src.util import safe_map

-from .datatype import NDArray, ScaledArray
+from .datatype import NDArray, ScaledArray, is_scaled_leaf


 class ScaledPrimitiveType(IntEnum):
@@ -138,14 +138,16 @@ def wrapped(*args, **kwargs):
         if len(kwargs) > 0:
             raise NotImplementedError("`autoscale` JAX interpreter not supporting named tensors at present.")

-        aval_args = safe_map(lambda x: _get_aval(x), args)
+        aval_args = jax.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
         # Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
         closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
         out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)

+        # Flattening of PyTree inputs.
         inputs_scaled = args
+        inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf)
         # Trace the graph & convert to scaled one.
-        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled)
+        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat)
         # Reconstruct the output Pytree, with scaled arrays.
         # NOTE: this step is also handling single vs multi outputs.
         assert len(out_leaves) == len(outputs_scaled_flat)
@@ -174,6 +176,8 @@ def promote_to_scaled_array(val):
         # No promotion rule => just return as such.
         return val

+    # A few initial checks to make sure there is consistency.
+    assert len(jaxpr.invars) == len(args)
     safe_map(write, jaxpr.invars, args)
     safe_map(write, jaxpr.constvars, consts)

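The key mechanism in the change above is the `is_leaf` hook of JAX's tree utilities: `ScaledArray` is presumably itself registered as a pytree node (hence the need for `is_scaled_leaf`), so a naive flatten would tear it into its data/scale children, while `is_leaf=is_scaled_leaf` keeps each scaled array whole. A self-contained sketch of the pattern, using a hypothetical `MyLeaf` class as a stand-in for `ScaledArray`:

from dataclasses import dataclass
import jax

# Hypothetical stand-in for ScaledArray: a registered pytree node, so by
# default JAX flattening would decompose it into its (data, scale) children.
@jax.tree_util.register_pytree_node_class
@dataclass
class MyLeaf:
    data: float
    scale: float

    def tree_flatten(self):
        return (self.data, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

tree = {"x": MyLeaf(1.0, 3.0), "y": MyLeaf(1.5, 2.0)}

# Default flattening recurses into MyLeaf: four scalar leaves.
assert len(jax.tree_util.tree_flatten(tree)[0]) == 4

# With an is_leaf predicate, each MyLeaf stays whole: two leaves, which is
# what the interpreter needs to pass scaled arrays through intact.
leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf=lambda x: isinstance(x, MyLeaf))
assert len(leaves) == 2 and isinstance(leaves[0], MyLeaf)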
tests/core/test_interpreter.py (14 changes: 12 additions & 2 deletions)
@@ -66,10 +66,20 @@ def myfunc(x):
         assert jaxpr.outvars[0].aval.shape == scaled_input.shape
         assert jaxpr.outvars[1].aval.shape == ()

-    @chex.variants(with_jit=False, without_jit=True)
+    @chex.variants(with_jit=True, without_jit=True)
     @parameterized.parameters(
         # Identity function!
         {"fn": lambda x: x, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
+        # Non-trivial input JAX pytree.
+        {
+            "fn": lambda vals: vals["x"] * vals["y"],
+            "inputs": [
+                {
+                    "x": scaled_array([1.0, 2.0], 3, dtype=np.float32),
+                    "y": scaled_array([1.5, -2.5], 2, dtype=np.float32),
+                }
+            ],
+        },
         # Non-trivial output JAX pytree
         {"fn": lambda x: {"x": (x,)}, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
         # Multi-inputs operation.
@@ -103,7 +113,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
         scaled_fn = self.variant(autoscale(fn))
         scaled_output = scaled_fn(*inputs)
         # Normal JAX path, without scaled arrays.
-        raw_inputs = [np.asarray(v) for v in inputs]
+        raw_inputs = jax.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
         expected_output = self.variant(fn)(*raw_inputs)

         # Do we re-construct properly the output type (i.e. handling Pytree properly)?
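The same `is_leaf` pattern appears on the test side: `jax.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)` converts every scaled leaf of an arbitrarily nested input back to a plain NumPy array while preserving the container structure, which the old list comprehension could not do for dict inputs. A generic illustration with a hypothetical `Scaled` type (assuming, as the test does, that `np.asarray` knows how to densify a scaled value):

import jax
import numpy as np

# Hypothetical scaled value; __array__ lets np.asarray recover the dense data.
class Scaled:
    def __init__(self, data, scale):
        self.data, self.scale = np.asarray(data), scale

    def __array__(self):
        return self.data * self.scale

def is_scaled(x):
    return isinstance(x, Scaled)

inputs = ({"x": Scaled([1.0, 2.0], 3), "y": Scaled([1.5, -2.5], 2)},)
# tree_map visits every Scaled leaf whole (thanks to is_leaf) and replaces it
# with its dense NumPy equivalent, keeping the (tuple, dict) structure intact.
raw_inputs = jax.tree_util.tree_map(np.asarray, inputs, is_leaf=is_scaled)
assert raw_inputs[0]["x"].tolist() == [3.0, 6.0]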