From e590d6f86052d77204303696771433f122c65751 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 24 Nov 2023 14:30:28 +0000
Subject: [PATCH] Support complex PyTree inputs in `autoscale`. (#33)

Properly use `jax.tree_map` instead of a basic `map`, and flatten the PyTree inputs.
---
 jax_scaled_arithmetics/core/interpreters.py | 10 +++++++---
 tests/core/test_interpreter.py              | 14 ++++++++++++--
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index d4ab441..4010c94 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -9,7 +9,7 @@
 from jax._src.pjit import pjit_p
 from jax._src.util import safe_map
 
-from .datatype import NDArray, ScaledArray
+from .datatype import NDArray, ScaledArray, is_scaled_leaf
 
 
 class ScaledPrimitiveType(IntEnum):
@@ -138,14 +138,16 @@ def wrapped(*args, **kwargs):
         if len(kwargs) > 0:
             raise NotImplementedError("`autoscale` JAX interpreter not supporting named tensors at present.")
 
-        aval_args = safe_map(lambda x: _get_aval(x), args)
+        aval_args = jax.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
         # Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
         closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
         out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
 
+        # Flattening of PyTree inputs.
         inputs_scaled = args
+        inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf)
         # Trace the graph & convert to scaled one.
-        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled)
+        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat)
         # Reconstruct the output Pytree, with scaled arrays.
         # NOTE: this step is also handling single vs multi outputs.
         assert len(out_leaves) == len(outputs_scaled_flat)
@@ -174,6 +176,8 @@ def promote_to_scaled_array(val):
         # No promotion rule => just return as such.
         return val
 
+    # A few initial checks to make sure there is consistency.
+    assert len(jaxpr.invars) == len(args)
     safe_map(write, jaxpr.invars, args)
     safe_map(write, jaxpr.constvars, consts)
 
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index c65bc6d..896f97a 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -66,10 +66,20 @@ def myfunc(x):
         assert jaxpr.outvars[0].aval.shape == scaled_input.shape
         assert jaxpr.outvars[1].aval.shape == ()
 
-    @chex.variants(with_jit=False, without_jit=True)
+    @chex.variants(with_jit=True, without_jit=True)
     @parameterized.parameters(
         # Identity function!
         {"fn": lambda x: x, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
+        # Non-trivial input JAX pytree.
+        {
+            "fn": lambda vals: vals["x"] * vals["y"],
+            "inputs": [
+                {
+                    "x": scaled_array([1.0, 2.0], 3, dtype=np.float32),
+                    "y": scaled_array([1.5, -2.5], 2, dtype=np.float32),
+                }
+            ],
+        },
         # Non-trivial output JAX pytree
         {"fn": lambda x: {"x": (x,)}, "inputs": [scaled_array([1.0, 2.0], 3, dtype=np.float32)]},
         # Multi-inputs operation.
@@ -103,7 +113,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
         scaled_fn = self.variant(autoscale(fn))
         scaled_output = scaled_fn(*inputs)
         # Normal JAX path, without scaled arrays.
-        raw_inputs = [np.asarray(v) for v in inputs]
+        raw_inputs = jax.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
         expected_output = self.variant(fn)(*raw_inputs)
         # Do we re-construct properly the output type (i.e. handling Pytree properly)?
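
Usage sketch illustrating the PyTree input support added above, assuming `autoscale`
and the `scaled_array` test helper are importable from `jax_scaled_arithmetics.core`
(the diff shows the module paths and the helper's call signature, but not the package's
exact public exports, so the import line is an assumption):

    import numpy as np

    # Assumed re-exports; adjust to the package's actual public import surface.
    from jax_scaled_arithmetics.core import autoscale, scaled_array

    def fn(vals):
        # The input is a dict PyTree whose leaves are scaled arrays.
        return vals["x"] * vals["y"]

    inputs = {
        "x": scaled_array([1.0, 2.0], 3, dtype=np.float32),
        "y": scaled_array([1.5, -2.5], 2, dtype=np.float32),
    }

    # With this patch, `autoscale` maps `_get_aval` over the input PyTree using
    # `is_leaf=is_scaled_leaf` and flattens it before tracing the jaxpr, so nested
    # containers of scaled arrays are accepted in addition to flat positional arrays.
    scaled_out = autoscale(fn)(inputs)
    print(scaled_out)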