Skip to content

Commit

Permalink
Implement AutoScale/Scalify TracerMetaArray data structure.
Browse files Browse the repository at this point in the history
Introducing the dataclass `ScalifyTracerArray` in JSA interpreter/tracer in order
to be able to pass additional metadata on the array (e.g. whether it is a broadcasted
scalar tensor).
  • Loading branch information
balancap committed Jan 29, 2024
1 parent 9b862ef commit 868e154
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 48 deletions.
189 changes: 143 additions & 46 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax
import numpy as np
Expand All @@ -14,9 +14,10 @@
custom_vjp_call_p,
)
from jax._src.util import safe_map
from jax.tree_util import register_pytree_node_class

from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode
from .datatype import Array, ArrayTypes, DTypeLike, ScaledArray, Shape, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode, python_scalar_as_numpy


@dataclass(frozen=True)
Expand Down Expand Up @@ -68,7 +69,17 @@ class ScaledPrimitiveType(IntEnum):
ALWAYS_SCALE = 2


_scaled_jaxpr_ops_registry: Dict[core.Primitive, Any] = {}
"""Registry of (sub) "jaxpr" ops/primitives and their scaled translation.
The core "jaxpr" primitives are typical `pjit`, `xla_call`, where the JSA interpreter
will need to be run on sub-jaxprs, passing the full metadata on input/output tensors.
"""


_scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {}
"""Registry of JAX common primitives and their scaled translation.
"""


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
Expand Down Expand Up @@ -118,6 +129,8 @@ def register_scaled_op(
scaled_type: Scaled primitive type => behaviour when `autoscale` tracing.
"""
assert isinstance(prim, core.Primitive)
# Can not register a jaxpr type op this way.
assert prim not in _scaled_jaxpr_ops_registry
if prim in _scaled_ops_registry:
raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
_scaled_ops_registry[prim] = (scaled_func, scaled_type)
Expand Down Expand Up @@ -147,6 +160,80 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv
return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val


@register_pytree_node_class
@dataclass(frozen=True, init=False)
class ScalifyTracerArray:
"""Meta-Array class used in scalify tracer. It can represent
any array, scaled or not, and tracks whether an array corresponds to a scalar broadcasted.
Compatible with JAX PyTrees in order to be able to trace a graph with `ScalifyTracerArray`
as inputs/outputs.
Args:
array: Normal or scaled array.
is_broadcasted_scalar: Is the array a broadcasted scalar (metadata).
"""

array: Union[Array, ScaledArray] = None
is_broadcasted_scalar: bool = False

def __init__(self, arr: Union[Array, ScaledArray], is_broadcasted_scalar: Optional[bool] = None) -> None:
# Convert Python scalars, if necessary.
arr = python_scalar_as_numpy(arr)
assert isinstance(arr, (np.bool_, np.number, np.ndarray, ScaledArray, *ArrayTypes))
object.__setattr__(self, "array", arr)
# Optional is broadcasted scalar information.
is_scalar = self.array.size == 1
is_broadcasted_scalar = is_scalar if is_broadcasted_scalar is None else is_broadcasted_scalar or is_scalar
object.__setattr__(self, "is_broadcasted_scalar", is_broadcasted_scalar)

def tree_flatten(self):
# See official JAX documentation on extending PyTrees.
# Note: using explicit tree flatten instead of chex for MyPy compatibility.
children = (self.array,)
aux_data = (self.is_broadcasted_scalar,)
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
# See official JAX documentation on extending PyTrees.
assert len(aux_data) == 1
assert len(children) == 1
return cls(children[0], aux_data[0])

@property
def size(self) -> int:
return self.array.size

@property
def shape(self) -> Shape:
return self.array.shape

@property
def is_scaled_array(self) -> bool:
return isinstance(self.array, ScaledArray)

def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
if self.is_scaled_array:
return self.array
# TODO: improve the logic for broadcasted scalar arrays!
return promote_to_scaled_array(self.array, scale_dtype)

def to_array(self) -> Array:
if not self.is_scaled_array:
return self.array
return self.array.to_array()


def autoscale(fun):
"""`autoscale` JAX graph transformation.
Expand Down Expand Up @@ -174,8 +261,12 @@ def wrapped(*args, **kwargs):
# Flattening of PyTree inputs.
inputs_scaled = args
inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf)
# Convert to Scalify tracer (meta) arrays.
inputs_tracer_flat = list(map(ScalifyTracerArray, inputs_scaled_flat))
consts_tracer_flat = list(map(ScalifyTracerArray, closed_jaxpr.literals))
# Trace the graph & convert to scaled one.
outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat)
outputs_tracer_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, consts_tracer_flat, *inputs_tracer_flat)
outputs_scaled_flat = [v.array for v in outputs_tracer_flat]
# Reconstruct the output Pytree, with scaled arrays.
# NOTE: this step is also handling single vs multi outputs.
assert len(out_leaves) == len(outputs_scaled_flat)
Expand All @@ -185,66 +276,73 @@ def wrapped(*args, **kwargs):
return wrapped


def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}
def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
"""Bind a Jaxpr equation to arrays."""
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
return outvals


def autoscale_jaxpr(jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray):
env: Dict[core.Var, ScalifyTracerArray] = {}
# Check dtype consistency between normal and scaled modes.
safe_check_dtypes: bool = False
# AutoScale config to use.
autoscale_cfg = get_autoscale_config()

def read(var):
def read(var) -> ScalifyTracerArray:
if type(var) is core.Literal:
return var.val
# Wrap the constant in tracer array.
return ScalifyTracerArray(var.val)
return env[var]

def write(var, val):
def write(var, val: ScalifyTracerArray):
env[var] = val

def promote_to_scaled_array(val, scale_dtype):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val

def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
return outvals

# A few initial checks to make sure there is consistency.
assert len(jaxpr.invars) == len(args)
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

for eqn in jaxpr.eqns:
invals = safe_map(read, eqn.invars)
invals_tracer: List[ScalifyTracerArray] = safe_map(read, eqn.invars)
if eqn.primitive in _scaled_jaxpr_ops_registry:
# Core sub-jaxpr primitive => pass the complete tracer array with metadata.
scaled_jaxpr_prim_fn = _scaled_jaxpr_ops_registry[eqn.primitive]
outvals_tracer = scaled_jaxpr_prim_fn(*invals_tracer, **eqn.params)
# Save outputs and move on!
safe_map(write, eqn.outvars, outvals_tracer)
continue

# Common (scaled) JAX primitives path.
# Is there any ScaledArray among inputs?
any_scaled_inputs = any([isinstance(v, ScaledArray) for v in invals])
any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer])
# Is there a scaled primitive associated?
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER))

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
invals = [v.to_array() for v in invals_tracer]
outvals = jaxpr_eqn_bind(eqn, invals)
outvals_tracer = list(map(ScalifyTracerArray, outvals))
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals))
scaled_invals = [v.to_scaled_array(autoscale_cfg.scale_dtype) for v in invals_tracer]
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
outvals_tracer = list(map(ScalifyTracerArray, outvals))

# Check consistency with normal JAX mode. Help catching dtype promotion errors.
# NOTE: ignoring when no outputs! (e.g. debug_callback).
if safe_check_dtypes and len(outvals) > 0:
ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v) for v in invals])
ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v.array) for v in invals_tracer])
data_outvals = [_get_data(v) for v in outvals]
# Check scaled dtypes == ref dtypes.
ref_dtypes = tuple(v.dtype for v in ref_outvals)
Expand All @@ -254,13 +352,13 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se
f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}."
)

safe_map(write, eqn.outvars, outvals)
safe_map(write, eqn.outvars, outvals_tracer)

outvals = safe_map(read, jaxpr.outvars)
return outvals
outvals_tracer = safe_map(read, jaxpr.outvars)
return outvals_tracer


def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `pjit`. Basically re-running `autoscale` on sub-jaxpr.
NOTE: the `pjit` call will be kept, forwarding the proper parameters (shardings, ...).
Expand All @@ -274,24 +372,24 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale
# in_shardings = kwargs["in_shardings"]
# out_shardings = kwargs["out_shardings"]

consts_tracer_flat = [ScalifyTracerArray(v) for v in closed_jaxpr.literals]
# Generate the sub-scaled function, with proper `jax.jit` options.
subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, closed_jaxpr.literals)
subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, consts_tracer_flat)
subfunc.__name__ = name # type:ignore
subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused)

outputs_scaled_flat = subfunc(*args)
return outputs_scaled_flat
outvals = subfunc(*args)
return outvals


try:
from jax._src.pjit import pjit_p

register_scaled_op(pjit_p, scaled_pjit_translation)
_scaled_jaxpr_ops_registry[pjit_p] = scaled_pjit_translation
except (ImportError, ModuleNotFoundError):
pass


def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
def scaled_xla_call_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `xla_call`. Basically re-running `autoscale` on sub-jaxpr.
Useful for JAX 0.3 compatibility
Expand All @@ -310,20 +408,19 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S
subfunc = partial(autoscale_jaxpr, jaxpr, [])
subfunc.__name__ = name # type:ignore
subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused)

outputs_scaled_flat = subfunc(*args)
return outputs_scaled_flat


try:
from jax.interpreters.xla import xla_call_p

register_scaled_op(xla_call_p, scaled_xla_call_translation)
_scaled_jaxpr_ops_registry[xla_call_p] = scaled_xla_call_translation
except (ImportError, ModuleNotFoundError):
pass


def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
def scaled_custom_jvp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `custom_jvp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `jvp` function.
"""
Expand All @@ -337,11 +434,11 @@ def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Seq
return call_subfunc(*args)


register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation)
register_scaled_op(custom_jvp_call_jaxpr_p, scaled_custom_jvp_call_translation)
_scaled_jaxpr_ops_registry[custom_jvp_call_p] = scaled_custom_jvp_call_translation
_scaled_jaxpr_ops_registry[custom_jvp_call_jaxpr_p] = scaled_custom_jvp_call_translation


def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]:
"""Scaled translation of `custom_vjp_call` primitive. Forwarding the scaled call to sub-jaxpr,
and modifying the underlying `vjp` function.
"""
Expand All @@ -352,5 +449,5 @@ def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Seq
return call_subfunc(*args)


register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation)
register_scaled_op(custom_vjp_call_jaxpr_p, scaled_custom_vjp_call_translation)
_scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation
_scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation
16 changes: 16 additions & 0 deletions jax_scaled_arithmetics/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@ def safe_reciprocal(val: Array) -> Array:
return np.reciprocal(val, out=np.array(0, dtype=val.dtype), where=val != 0)
# JAX general implementation.
return jax.lax.select(val == 0, val, jax.lax.reciprocal(val))


def python_scalar_as_numpy(val: Any) -> Any:
"""Convert Python scalar to Numpy scalar, if possible.
Using by default JAX 32 bits precision, instead of 64 bits.
Returning unchanged value if not any (bool, int, float).
"""
if isinstance(val, bool):
return np.bool_(val)
elif isinstance(val, int):
return np.int32(val)
elif isinstance(val, float):
return np.float32(val)
return val
Loading

0 comments on commit 868e154

Please sign in to comment.