diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 9add00a..d62773f 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import IntEnum from functools import partial, wraps -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import jax import numpy as np @@ -14,8 +14,9 @@ custom_vjp_call_p, ) from jax._src.util import safe_map +from jax.tree_util import register_pytree_node_class -from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf +from .datatype import Array, ArrayTypes, DTypeLike, ScaledArray, Shape, as_scaled_array_base, is_scaled_leaf from .utils import Pow2RoundMode @@ -68,7 +69,17 @@ class ScaledPrimitiveType(IntEnum): ALWAYS_SCALE = 2 +_scaled_jaxpr_ops_registry: Dict[core.Primitive, Any] = {} +"""Registry of (sub) "jaxpr" ops/primitives and their scaled translation. + +The core "jaxpr" primitives are typical `pjit`, `xla_call`, where the JSA interpreter +will need to be run on sub-jaxprs, passing the full metadata on input/output tensors. +""" + + _scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {} +"""Registry of JAX common primitives and their scaled translation. +""" def _get_lax_prim(scaled_func: Any) -> core.Primitive: @@ -147,6 +158,88 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER)) +def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None): + if isinstance(val, ScaledArray): + return val + elif np.ndim(val) == 0: + return promote_scalar_to_scaled_array(val, scale_dtype) + # No promotion rule => just return as such. + return val + + +def convert_python_scalar(val: Any) -> Any: + """Convert Python scalar to Numpy if necessary.""" + if isinstance(val, int): + return np.int32(val) + elif isinstance(val, float): + return np.float32(val) + return val + + +@register_pytree_node_class +@dataclass(frozen=True, init=False) +class ScalifyTracerArray: + """Meta-Array class used in scalify tracer. It can represent + any array, scaled or not, and tracks whether an array corresponds to a scalar broadcasted. + + + Args: + array: Array component, if it is a normal array. + scaled_array: Scaled array, if representing a scaled array. + is_broadcasted_scalar: Is the array a broadcasted scalar. + """ + + array: Union[Array, ScaledArray] = None + is_broadcasted_scalar: bool = False + + def __init__(self, arr: Union[Array, ScaledArray], is_broadcasted_scalar: Optional[bool] = None) -> None: + # Convert Python scalars, if necessary. + arr = convert_python_scalar(arr) + assert isinstance(arr, (np.number, np.ndarray, ScaledArray, *ArrayTypes)) + object.__setattr__(self, "array", arr) + # Optional is broadcasted scalar information. + is_scalar = self.array.size == 1 + is_broadcasted_scalar = is_scalar if is_broadcasted_scalar is None else is_broadcasted_scalar or is_scalar + object.__setattr__(self, "is_broadcasted_scalar", is_broadcasted_scalar) + + def tree_flatten(self): + # See official JAX documentation on extending PyTrees. + # Note: using explicit tree flatten instead of chex for MyPy compatibility. + children = (self.array,) + aux_data = (self.is_broadcasted_scalar,) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + # See official JAX documentation on extending PyTrees. + assert len(aux_data) == 1 + assert len(children) == 1 + return cls(children[0], aux_data[0]) + + @property + def size(self) -> int: + return self.array.size + + @property + def shape(self) -> Shape: + return self.array.shape + + @property + def is_scaled_array(self) -> bool: + return isinstance(self.array, ScaledArray) + + def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray: + if self.is_scaled_array: + return self.array + # TODO: improve the logic for broadcasted scalar arrays! + return promote_to_scaled_array(self.array, scale_dtype) + + def to_array(self) -> Array: + if not self.is_scaled_array: + return self.array + return self.array.to_array() + + def autoscale(fun): """`autoscale` JAX graph transformation. @@ -174,8 +267,12 @@ def wrapped(*args, **kwargs): # Flattening of PyTree inputs. inputs_scaled = args inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf) + # Convert to Scalify tracer (meta) arrays. + inputs_tracer_flat = list(map(ScalifyTracerArray, inputs_scaled_flat)) + consts_tracer_flat = list(map(ScalifyTracerArray, closed_jaxpr.literals)) # Trace the graph & convert to scaled one. - outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat) + outputs_tracer_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, consts_tracer_flat, *inputs_tracer_flat) + outputs_scaled_flat = [v.array for v in outputs_tracer_flat] # Reconstruct the output Pytree, with scaled arrays. # NOTE: this step is also handling single vs multi outputs. assert len(out_leaves) == len(outputs_scaled_flat) @@ -185,82 +282,89 @@ def wrapped(*args, **kwargs): return wrapped -def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args): - env: Dict[core.Var, ScaledArray] = {} +def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]: + """Bind a Jaxpr equation to arrays.""" + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + return outvals + + +def autoscale_jaxpr(jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray): + env: Dict[core.Var, ScalifyTracerArray] = {} # Check dtype consistency between normal and scaled modes. safe_check_dtypes: bool = False # AutoScale config to use. autoscale_cfg = get_autoscale_config() - def read(var): + def read(var) -> ScalifyTracerArray: if type(var) is core.Literal: - return var.val + # Wrap the constant in tracer array. + return ScalifyTracerArray(var.val) return env[var] - def write(var, val): + def write(var, val: ScalifyTracerArray): + assert isinstance(val, ScalifyTracerArray) env[var] = val - def promote_to_scaled_array(val, scale_dtype): - if isinstance(val, ScaledArray): - return val - elif np.ndim(val) == 0: - return promote_scalar_to_scaled_array(val, scale_dtype) - # No promotion rule => just return as such. - return val - - def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]: - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params) - if not eqn.primitive.multiple_results: - outvals = [outvals] - return outvals - # A few initial checks to make sure there is consistency. assert len(jaxpr.invars) == len(args) safe_map(write, jaxpr.invars, args) safe_map(write, jaxpr.constvars, consts) for eqn in jaxpr.eqns: - invals = safe_map(read, eqn.invars) - # Is there any ScaledArray among inputs? - any_scaled_inputs = any([isinstance(v, ScaledArray) for v in invals]) - # Is there a scaled primitive associated? - scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER)) - - if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE: - # Using normal JAX primitive: no scaled inputs, and not always scale rule. - outvals = jaxpr_eqn_bind(eqn, invals) - elif scaled_prim_fn is None: - raise NotImplementedError( - f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet." - ) + invals_tracer: List[ScalifyTracerArray] = safe_map(read, eqn.invars) + if eqn.primitive in _scaled_jaxpr_ops_registry: + # Core sub-jaxpr primitive => pass the complete tracer array with metadata. + scaled_jaxpr_prim_fn = _scaled_jaxpr_ops_registry[eqn.primitive] + outvals_tracer = scaled_jaxpr_prim_fn(*invals_tracer, **eqn.params) else: - # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible. - scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals)) - outvals = scaled_prim_fn(*scaled_invals, **eqn.params) - if not eqn.primitive.multiple_results: - outvals = [outvals] - - # Check consistency with normal JAX mode. Help catching dtype promotion errors. - # NOTE: ignoring when no outputs! (e.g. debug_callback). - if safe_check_dtypes and len(outvals) > 0: - ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v) for v in invals]) - data_outvals = [_get_data(v) for v in outvals] - # Check scaled dtypes == ref dtypes. - ref_dtypes = tuple(v.dtype for v in ref_outvals) - data_dtypes = tuple(v.dtype for v in data_outvals) - if data_dtypes != ref_dtypes: - raise ValueError( - f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}." - ) - - safe_map(write, eqn.outvars, outvals) - - outvals = safe_map(read, jaxpr.outvars) - return outvals - + # Common primitives path. + # Is there any ScaledArray among inputs? + any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer]) + # Is there a scaled primitive associated? + scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get( + eqn.primitive, (None, ScaledPrimitiveType.NEVER) + ) -def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]: + if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE: + # Using normal JAX primitive: no scaled inputs, and not always scale rule. + invals = [v.to_array() for v in invals_tracer] + outvals = jaxpr_eqn_bind(eqn, invals) + outvals_tracer = [ScalifyTracerArray(v) for v in outvals] + elif scaled_prim_fn is None: + raise NotImplementedError( + f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet." + ) + else: + # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible. + scaled_invals = [v.to_scaled_array(autoscale_cfg.scale_dtype) for v in invals_tracer] + outvals = scaled_prim_fn(*scaled_invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + outvals_tracer = [ScalifyTracerArray(v) for v in outvals] + + # Check consistency with normal JAX mode. Help catching dtype promotion errors. + # NOTE: ignoring when no outputs! (e.g. debug_callback). + if safe_check_dtypes and len(outvals) > 0: + ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v.array) for v in invals_tracer]) + data_outvals = [_get_data(v) for v in outvals] + # Check scaled dtypes == ref dtypes. + ref_dtypes = tuple(v.dtype for v in ref_outvals) + data_dtypes = tuple(v.dtype for v in data_outvals) + if data_dtypes != ref_dtypes: + raise ValueError( + f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}." + ) + + safe_map(write, eqn.outvars, outvals_tracer) + + outvals_tracer = safe_map(read, jaxpr.outvars) + return outvals_tracer + + +def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]: """Scaled translation of `pjit`. Basically re-running `autoscale` on sub-jaxpr. NOTE: the `pjit` call will be kept, forwarding the proper parameters (shardings, ...). @@ -274,24 +378,26 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale # in_shardings = kwargs["in_shardings"] # out_shardings = kwargs["out_shardings"] + consts_tracer_flat = [ScalifyTracerArray(v) for v in closed_jaxpr.literals] # Generate the sub-scaled function, with proper `jax.jit` options. - subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, closed_jaxpr.literals) + subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, consts_tracer_flat) subfunc.__name__ = name # type:ignore + # FIXME => getting a jax.jit added to jaxpr graph. subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused) - - outputs_scaled_flat = subfunc(*args) - return outputs_scaled_flat + outvals = subfunc(*args) + return outvals try: from jax._src.pjit import pjit_p register_scaled_op(pjit_p, scaled_pjit_translation) + _scaled_jaxpr_ops_registry[pjit_p] = scaled_pjit_translation except (ImportError, ModuleNotFoundError): pass -def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]: +def scaled_xla_call_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]: """Scaled translation of `xla_call`. Basically re-running `autoscale` on sub-jaxpr. Useful for JAX 0.3 compatibility @@ -310,7 +416,6 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S subfunc = partial(autoscale_jaxpr, jaxpr, []) subfunc.__name__ = name # type:ignore subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused) - outputs_scaled_flat = subfunc(*args) return outputs_scaled_flat @@ -319,11 +424,12 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S from jax.interpreters.xla import xla_call_p register_scaled_op(xla_call_p, scaled_xla_call_translation) + _scaled_jaxpr_ops_registry[xla_call_p] = scaled_xla_call_translation except (ImportError, ModuleNotFoundError): pass -def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]: +def scaled_custom_jvp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]: """Scaled translation of `custom_jvp_call` primitive. Forwarding the scaled call to sub-jaxpr, and modifying the underlying `jvp` function. """ @@ -340,8 +446,11 @@ def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Seq register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation) register_scaled_op(custom_jvp_call_jaxpr_p, scaled_custom_jvp_call_translation) +_scaled_jaxpr_ops_registry[custom_jvp_call_p] = scaled_custom_jvp_call_translation +_scaled_jaxpr_ops_registry[custom_jvp_call_jaxpr_p] = scaled_custom_jvp_call_translation -def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]: + +def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]: """Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr, and modifying the underlying `vjp` function. """ @@ -354,3 +463,6 @@ def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Seq register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation) register_scaled_op(custom_vjp_call_jaxpr_p, scaled_custom_vjp_call_translation) + +_scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation +_scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index abebc69..6333fb7 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -19,7 +19,56 @@ register_scaled_op, scaled_array, ) -from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array +from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray, promote_scalar_to_scaled_array + + +class ScalifyTracerArrayTests(chex.TestCase): + @parameterized.parameters( + {"arr": True}, + {"arr": 2}, + {"arr": 3.0}, + ) + def test__scalify_tracer_array__init__from_python_value(self, arr): + tracer_arr = ScalifyTracerArray(arr) + assert tracer_arr.array == arr + assert not tracer_arr.is_scaled_array + assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1) + assert tracer_arr.to_array() is tracer_arr.array + + @parameterized.parameters( + {"arr": np.float32(2)}, + {"arr": np.array([1, 2])}, + {"arr": jnp.array([3, 4])}, + ) + def test__scalify_tracer_array__init__from_normal_array(self, arr): + tracer_arr = ScalifyTracerArray(arr) + assert tracer_arr.array is arr + assert not tracer_arr.is_scaled_array + assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1) + assert tracer_arr.to_array() is tracer_arr.array + + @parameterized.parameters({"arr": scaled_array([1, 2], 3.0)}) + def test__scalify_tracer_array__init__from_scaled_array(self, arr): + tracer_arr = ScalifyTracerArray(arr) + assert tracer_arr.array is arr + assert tracer_arr.is_scaled_array + assert tracer_arr.to_scaled_array() is tracer_arr.array + + def test__scalify_tracer_array__init__is_broadcasted_scalar_kwarg(self): + arr = scaled_array([1, 2], 3.0) + assert ScalifyTracerArray(arr, is_broadcasted_scalar=True).is_broadcasted_scalar + assert not ScalifyTracerArray(arr, is_broadcasted_scalar=False).is_broadcasted_scalar + + def test__scalify_tracer_array__flatten__proper_pytree(self): + arr = scaled_array([1, 2], 3.0) + tracer_arr_in = ScalifyTracerArray(arr, True) + # Proper round trip! + flat_arrays, pytree = jax.tree_util.tree_flatten(tracer_arr_in) + tracer_arr_out = jax.tree_util.tree_unflatten(pytree, flat_arrays) + + assert isinstance(tracer_arr_out, ScalifyTracerArray) + assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar + npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array)) class AutoScaleInterpreterTests(chex.TestCase):