Implement AutoScale/Scalify TracerMetaArray data structure.
Introduce the dataclass `ScalifyTracerArray` in the JSA interpreter/tracer so that
additional metadata can be carried alongside an array (e.g. whether it is a broadcasted
scalar tensor).
balancap committed Jan 26, 2024
1 parent 9b862ef commit 87c195b
Showing 2 changed files with 231 additions and 70 deletions.
250 changes: 181 additions & 69 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax
import numpy as np
@@ -14,8 +14,9 @@
custom_vjp_call_p,
)
from jax._src.util import safe_map
from jax.tree_util import register_pytree_node_class

from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .datatype import Array, ArrayTypes, DTypeLike, ScaledArray, Shape, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


@@ -68,7 +69,17 @@ class ScaledPrimitiveType(IntEnum):
ALWAYS_SCALE = 2


_scaled_jaxpr_ops_registry: Dict[core.Primitive, Any] = {}
"""Registry of (sub) "jaxpr" ops/primitives and their scaled translation.
The core "jaxpr" primitives are typical `pjit`, `xla_call`, where the JSA interpreter
will need to be run on sub-jaxprs, passing the full metadata on input/output tensors.
"""


_scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {}
"""Registry of JAX common primitives and their scaled translation.
"""


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
@@ -147,6 +158,88 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv
return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None):
    if isinstance(val, ScaledArray):
        return val
    elif np.ndim(val) == 0:
        return promote_scalar_to_scaled_array(val, scale_dtype)
    # No promotion rule => just return as such.
    return val


def convert_python_scalar(val: Any) -> Any:
    """Convert Python scalar to Numpy if necessary."""
    if isinstance(val, int):
        return np.int32(val)
    elif isinstance(val, float):
        return np.float32(val)
    return val
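Two quick checks (a hedged sketch, not part of the diff) of how `convert_python_scalar` normalizes Python scalars:

# Sanity checks for convert_python_scalar (illustration only).
assert isinstance(convert_python_scalar(3), np.int32)        # Python int   -> np.int32
assert isinstance(convert_python_scalar(2.5), np.float32)    # Python float -> np.float32
assert convert_python_scalar(np.float16(1.0)).dtype == np.float16   # NumPy scalars pass through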


@register_pytree_node_class
@dataclass(frozen=True, init=False)
class ScalifyTracerArray:
    """Meta-array class used in the scalify tracer. It can represent any array, scaled or not,
    and tracks whether the array corresponds to a broadcasted scalar.

    Args:
        array: Underlying array, either a plain `Array` or a `ScaledArray`.
        is_broadcasted_scalar: Whether the array is a broadcasted scalar.
    """

    array: Union[Array, ScaledArray] = None
    is_broadcasted_scalar: bool = False

    def __init__(self, arr: Union[Array, ScaledArray], is_broadcasted_scalar: Optional[bool] = None) -> None:
        # Convert Python scalars, if necessary.
        arr = convert_python_scalar(arr)
        assert isinstance(arr, (np.number, np.ndarray, ScaledArray, *ArrayTypes))
        object.__setattr__(self, "array", arr)
        # Optional broadcasted-scalar information.
        is_scalar = self.array.size == 1
        is_broadcasted_scalar = is_scalar if is_broadcasted_scalar is None else is_broadcasted_scalar or is_scalar
        object.__setattr__(self, "is_broadcasted_scalar", is_broadcasted_scalar)

    def tree_flatten(self):
        # See official JAX documentation on extending PyTrees.
        # Note: using explicit tree flatten instead of chex for MyPy compatibility.
        children = (self.array,)
        aux_data = (self.is_broadcasted_scalar,)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # See official JAX documentation on extending PyTrees.
        assert len(aux_data) == 1
        assert len(children) == 1
        return cls(children[0], aux_data[0])

    @property
    def size(self) -> int:
        return self.array.size

    @property
    def shape(self) -> Shape:
        return self.array.shape

    @property
    def is_scaled_array(self) -> bool:
        return isinstance(self.array, ScaledArray)

    def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
        if self.is_scaled_array:
            return self.array
        # TODO: improve the logic for broadcasted scalar arrays!
        return promote_to_scaled_array(self.array, scale_dtype)

    def to_array(self) -> Array:
        if not self.is_scaled_array:
            return self.array
        return self.array.to_array()
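A short usage sketch of the new class follows (not part of the diff). The module import path matches the file shown here; the `ScaledArray(data, scale)` constructor is an assumption based on the library's `datatype` module.

# Hedged usage sketch of ScalifyTracerArray; ScaledArray(data, scale) is assumed.
import numpy as np
import jax
from jax_scaled_arithmetics.core.datatype import ScaledArray
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray

# A Python/NumPy scalar is flagged as a broadcasted scalar automatically.
t0 = ScalifyTracerArray(2.5)                # Python float converted to np.float32
assert t0.is_broadcasted_scalar and t0.size == 1

# Wrapping a ScaledArray keeps it as is; to_array() recovers the plain data.
sa = ScaledArray(np.ones((4,), np.float32), np.float32(0.5))
t1 = ScalifyTracerArray(sa)
assert t1.is_scaled_array and t1.shape == (4,)

# Registered as a pytree node: the metadata survives flatten/unflatten.
leaves, treedef = jax.tree_util.tree_flatten(t0)
t0_bis = jax.tree_util.tree_unflatten(treedef, leaves)
assert t0_bis.is_broadcasted_scalar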


def autoscale(fun):
"""`autoscale` JAX graph transformation.
@@ -174,8 +267,12 @@ def wrapped(*args, **kwargs):
# Flattening of PyTree inputs.
inputs_scaled = args
inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf)
# Convert to Scalify tracer (meta) arrays.
inputs_tracer_flat = list(map(ScalifyTracerArray, inputs_scaled_flat))
consts_tracer_flat = list(map(ScalifyTracerArray, closed_jaxpr.literals))
# Trace the graph & convert to scaled one.
outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat)
outputs_tracer_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, consts_tracer_flat, *inputs_tracer_flat)
outputs_scaled_flat = [v.array for v in outputs_tracer_flat]
# Reconstruct the output Pytree, with scaled arrays.
# NOTE: this step is also handling single vs multi outputs.
assert len(out_leaves) == len(outputs_scaled_flat)
@@ -185,82 +282,89 @@ def wrapped(*args, **kwargs):
return wrapped
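To make the transform concrete, here is a hedged end-to-end sketch of `autoscale` (not part of the diff). It assumes scaled translations for elementary ops such as multiply and add are registered elsewhere in the library, and reuses the assumed `ScaledArray(data, scale)` constructor.

# Hedged sketch: applying autoscale to a simple function with ScaledArray inputs.
import jax.numpy as jnp
import numpy as np
from jax_scaled_arithmetics.core.datatype import ScaledArray
from jax_scaled_arithmetics.core.interpreters import autoscale

@autoscale
def fn(a, b):
    # The literal 1.0 enters the trace as a broadcasted scalar, which is exactly
    # the metadata ScalifyTracerArray now propagates through the interpreter.
    return a * b + 1.0

a = ScaledArray(jnp.ones((8,), jnp.float32), np.float32(2.0))
b = ScaledArray(jnp.ones((8,), jnp.float32), np.float32(0.5))
out = fn(a, b)   # outputs are (typically) ScaledArray leaves with propagated scales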


def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}
def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
    """Bind a Jaxpr equation to arrays."""
    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
    outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
    if not eqn.primitive.multiple_results:
        outvals = [outvals]
    return outvals


def autoscale_jaxpr(jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray):
env: Dict[core.Var, ScalifyTracerArray] = {}
# Check dtype consistency between normal and scaled modes.
safe_check_dtypes: bool = False
# AutoScale config to use.
autoscale_cfg = get_autoscale_config()

def read(var):
def read(var) -> ScalifyTracerArray:
if type(var) is core.Literal:
return var.val
# Wrap the constant in tracer array.
return ScalifyTracerArray(var.val)
return env[var]

def write(var, val):
def write(var, val: ScalifyTracerArray):
assert isinstance(val, ScalifyTracerArray)
env[var] = val

def promote_to_scaled_array(val, scale_dtype):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val

def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
return outvals

# A few initial checks to make sure there is consistency.
assert len(jaxpr.invars) == len(args)
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

for eqn in jaxpr.eqns:
invals = safe_map(read, eqn.invars)
# Is there any ScaledArray among inputs?
any_scaled_inputs = any([isinstance(v, ScaledArray) for v in invals])
# Is there a scaled primitive associated?
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER))

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
outvals = jaxpr_eqn_bind(eqn, invals)
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
)
invals_tracer: List[ScalifyTracerArray] = safe_map(read, eqn.invars)
if eqn.primitive in _scaled_jaxpr_ops_registry:
# Core sub-jaxpr primitive => pass the complete tracer array with metadata.
scaled_jaxpr_prim_fn = _scaled_jaxpr_ops_registry[eqn.primitive]
outvals_tracer = scaled_jaxpr_prim_fn(*invals_tracer, **eqn.params)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals))
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]

# Check consistency with normal JAX mode. Help catching dtype promotion errors.
# NOTE: ignoring when no outputs! (e.g. debug_callback).
if safe_check_dtypes and len(outvals) > 0:
ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v) for v in invals])
data_outvals = [_get_data(v) for v in outvals]
# Check scaled dtypes == ref dtypes.
ref_dtypes = tuple(v.dtype for v in ref_outvals)
data_dtypes = tuple(v.dtype for v in data_outvals)
if data_dtypes != ref_dtypes:
raise ValueError(
f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}."
)

safe_map(write, eqn.outvars, outvals)

outvals = safe_map(read, jaxpr.outvars)
return outvals

# Common primitives path.
# Is there any ScaledArray among inputs?
any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer])
# Is there a scaled primitive associated?
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(
eqn.primitive, (None, ScaledPrimitiveType.NEVER)
)

def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
invals = [v.to_array() for v in invals_tracer]
outvals = jaxpr_eqn_bind(eqn, invals)
outvals_tracer = [ScalifyTracerArray(v) for v in outvals]
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
scaled_invals = [v.to_scaled_array(autoscale_cfg.scale_dtype) for v in invals_tracer]
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
outvals_tracer = [ScalifyTracerArray(v) for v in outvals]

# Check consistency with normal JAX mode. Help catching dtype promotion errors.
# NOTE: ignoring when no outputs! (e.g. debug_callback).
if safe_check_dtypes and len(outvals) > 0:
ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v.array) for v in invals_tracer])
data_outvals = [_get_data(v) for v in outvals]
# Check scaled dtypes == ref dtypes.
ref_dtypes = tuple(v.dtype for v in ref_outvals)
data_dtypes = tuple(v.dtype for v in data_outvals)
if data_dtypes != ref_dtypes:
raise ValueError(
f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}."
)

safe_map(write, eqn.outvars, outvals_tracer)

outvals_tracer = safe_map(read, jaxpr.outvars)
return outvals_tracer


def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `pjit`. Basically re-running `autoscale` on sub-jaxpr.
NOTE: the `pjit` call will be kept, forwarding the proper parameters (shardings, ...).
@@ -274,24 +378,26 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale
# in_shardings = kwargs["in_shardings"]
# out_shardings = kwargs["out_shardings"]

consts_tracer_flat = [ScalifyTracerArray(v) for v in closed_jaxpr.literals]
# Generate the sub-scaled function, with proper `jax.jit` options.
subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, closed_jaxpr.literals)
subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, consts_tracer_flat)
subfunc.__name__ = name # type:ignore
# FIXME => getting a jax.jit added to jaxpr graph.
subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused)

outputs_scaled_flat = subfunc(*args)
return outputs_scaled_flat
outvals = subfunc(*args)
return outvals


try:
from jax._src.pjit import pjit_p

register_scaled_op(pjit_p, scaled_pjit_translation)
_scaled_jaxpr_ops_registry[pjit_p] = scaled_pjit_translation
except (ImportError, ModuleNotFoundError):
pass
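A hedged sketch of what the `pjit` registration above enables: an inner `jax.jit` call inside an `autoscale`-transformed function is not treated as a black box; its sub-jaxpr is traversed again by `scaled_pjit_translation`, now with full `ScalifyTracerArray` metadata. Imports and the `ScaledArray` constructor are assumed as in the earlier sketches.

# Hedged sketch: scales propagate through a jit-ed sub-jaxpr.
import jax
import jax.numpy as jnp
import numpy as np
from jax_scaled_arithmetics.core.datatype import ScaledArray
from jax_scaled_arithmetics.core.interpreters import autoscale

@autoscale
def outer(x):
    inner = jax.jit(lambda v: v * 2.0)
    # The pjit equation is dispatched through _scaled_jaxpr_ops_registry,
    # so autoscale_jaxpr is re-run on the inner jaxpr.
    return inner(x)

x = ScaledArray(jnp.ones((4,), jnp.float32), np.float32(1.0))
y = outer(x)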


def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
def scaled_xla_call_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `xla_call`. Basically re-running `autoscale` on sub-jaxpr.
Useful for JAX 0.3 compatibility
@@ -310,7 +416,6 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S
subfunc = partial(autoscale_jaxpr, jaxpr, [])
subfunc.__name__ = name # type:ignore
subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused)

outputs_scaled_flat = subfunc(*args)
return outputs_scaled_flat

@@ -319,11 +424,12 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S
from jax.interpreters.xla import xla_call_p

register_scaled_op(xla_call_p, scaled_xla_call_translation)
_scaled_jaxpr_ops_registry[xla_call_p] = scaled_xla_call_translation
except (ImportError, ModuleNotFoundError):
pass


def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
def scaled_custom_jvp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `custom_jvp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `jvp` function.
"""
@@ -340,8 +446,11 @@ def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Seq
register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation)
register_scaled_op(custom_jvp_call_jaxpr_p, scaled_custom_jvp_call_translation)

_scaled_jaxpr_ops_registry[custom_jvp_call_p] = scaled_custom_jvp_call_translation
_scaled_jaxpr_ops_registry[custom_jvp_call_jaxpr_p] = scaled_custom_jvp_call_translation

def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:

def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `vjp` function.
"""
@@ -354,3 +463,6 @@ def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Seq

register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation)
register_scaled_op(custom_vjp_call_jaxpr_p, scaled_custom_vjp_call_translation)

_scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation
_scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation
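Finally, a hedged sketch of the `custom_jvp_call` path registered above: a user-defined `jax.custom_jvp` function is forwarded to its inner jaxpr rather than rejected. Same import-path and constructor assumptions as before.

# Hedged sketch: a custom_jvp function traced under autoscale.
import jax
import jax.numpy as jnp
import numpy as np
from jax_scaled_arithmetics.core.datatype import ScaledArray
from jax_scaled_arithmetics.core.interpreters import autoscale

@jax.custom_jvp
def double(x):
    return x * 2.0

@double.defjvp
def double_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    return double(x), dx * 2.0

@autoscale
def f(x):
    # custom_jvp_call is forwarded to its sub-jaxpr by the translation above.
    return double(x)

out = f(ScaledArray(jnp.ones((4,), jnp.float32), np.float32(1.0)))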
