Support mixed normal/scaled graph in autoscale
The `AutoScale` interpreter needs to be generalized to support mixed graphs, where some tensors still use normal JAX arrays.

This means we need some form of rules + promotions related to:
* When to use scaled primitives;
* When to automatically promote plain arrays to `ScaledArray` (see the sketch below).
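
As an illustration of the intended behaviour, a minimal usage sketch (assuming a `scaled_array(data, scale)` constructor; the exact promotion rule applied to `b` is precisely the open question above):

import numpy as np
from jax_scaled_arithmetics.core import autoscale, scaled_array

@autoscale
def fn(a, b):
    # `add` is a FORWARD primitive: its scaled translation fires because `a` is scaled.
    return a + b

a = scaled_array(np.array([1.0, 2.0], np.float32), np.float32(2.0))  # ScaledArray input.
b = np.array([3.0, 4.0], np.float32)  # Plain array, to be promoted by the interpreter.
out = fn(a, b)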
balancap committed Nov 17, 2023
1 parent b5920b7 commit ca9841c
Showing 3 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array # noqa: F401
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
from .interpreters import ScaledPrimitiveType, autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
59 changes: 48 additions & 11 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from enum import IntEnum
from functools import wraps
from typing import Any, Dict

@@ -13,6 +13,21 @@
_scaled_ops_registry: Dict[core.Primitive, Any] = {}


class ScaledPrimitiveType(IntEnum):
    """Scaled (JAX) primitive type.
    This enum describes the behaviour when `autoscale` is
    tracing the graph.
    FORWARD: forward scaling => the scaled translation is used only if
        at least one input is scaled. Default behaviour.
    ALWAYS_SCALE: always use the scaled translation.
    """

    FORWARD = 1
    ALWAYS_SCALE = 2


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
    """Get the ScaledArray corresponding to a Numpy constant.
@@ -24,19 +39,22 @@ def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
    return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))


def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
def register_scaled_op(
    prim: core.Primitive, scaled_func: Any, scaled_type: ScaledPrimitiveType = ScaledPrimitiveType.FORWARD
) -> None:
    """Register the scaled translation of a JAX primitive.
    Raises an error if a scaled translation is already registered for this primitive.
    Args:
        prim: JAX primitive.
        scaled_fund: Scaled translation of the primitive. With the same interface.
        scaled_func: Scaled translation of the primitive. With the same interface.
        scaled_type: Scaled primitive type => behaviour when `autoscale` is tracing.
    """
    assert isinstance(prim, core.Primitive)
    if prim in _scaled_ops_registry:
        raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
    _scaled_ops_registry[prim] = scaled_func
    _scaled_ops_registry[prim] = (scaled_func, scaled_type)


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
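
As a usage illustration of the new `register_scaled_op` signature above, a hypothetical registration of a scaled `exp` translation with the default FORWARD behaviour (the `scaled_exp` body is a sketch, not the library's implementation):

import jax
import numpy as np
from jax_scaled_arithmetics.core import ScaledArray, ScaledPrimitiveType, register_scaled_op

def scaled_exp(val: ScaledArray) -> ScaledArray:
    # Materialize the values, apply exp, and re-wrap with a trivial unit scale (illustrative only).
    data = jax.lax.exp(val.data * val.scale)
    return ScaledArray(data=data, scale=np.array(1.0, dtype=data.dtype))

register_scaled_op(jax.lax.exp_p, scaled_exp, ScaledPrimitiveType.FORWARD)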
@@ -60,25 +78,43 @@ def register_scaled_lax_op(scaled_func):
    Example: `scaled_mul` is matched to `jax.lax.mul_p`
    """
    lax_prim = _get_lax_prim(scaled_func)
    register_scaled_op(lax_prim, scaled_func)
    register_scaled_op(lax_prim, scaled_func, ScaledPrimitiveType.FORWARD)
    # Always return the function in the case of decorator use.
    return scaled_func


def autoscale(fun):
    """`autoscale` JAX graph transformation.
    The `autoscale` graph transformation works in a forwarding mode:
    scaled arrays are forwarded to scaled primitives, which will generate scaled outputs.
    If no input to a JAX primitive is scaled -> the normal primitive is called, generating a common
    JAX output array.
    This behaviour is the standard one for `ScaledPrimitiveType.FORWARD` primitives.
    An alternative behaviour is possible for `ScaledPrimitiveType.ALWAYS_SCALE` primitives, where the scaled
    operation will always be called. A typical example is the `set_scaling` primitive.
    """

    @wraps(fun)
    def wrapped(*args, **kwargs):
        if len(kwargs) > 0:
            raise NotImplementedError("`autoscale` JAX interpreter does not support named tensors at present.")

        aval_args = safe_map(lambda x: x.aval, args)
        # Get the jaxpr of the unscaled/normal graph, as well as the output Pytree shape.
        closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
        out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)

        inputs_scaled = args
        # Trace the graph & convert to a scaled one.
        outputs_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled)
        # Reconstruct the output Pytree, with scaled arrays.
        # NOTE: this step also handles single vs multiple outputs.
        assert len(out_leaves) == len(outputs_flat)
        output = jax.tree_util.tree_unflatten(out_pytree, outputs_flat)
        return output
        assert len(out_leaves) == len(outputs_scaled_flat)
        output_scaled = jax.tree_util.tree_unflatten(out_pytree, outputs_scaled_flat)
        return output_scaled

    return wrapped
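
The tracing pattern used in `wrapped` can be reproduced standalone; a minimal sketch of `jax.make_jaxpr(..., return_shape=True)` on a plain function (`f` is hypothetical, not from the library):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(x), x * 2

# `return_shape=True` additionally returns the output Pytree of jax.ShapeDtypeStruct,
# which is flattened to rebuild the (possibly nested) outputs after interpretation.
closed_jaxpr, outshape = jax.make_jaxpr(f, return_shape=True)(jnp.ones((3,), jnp.float32))
out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
assert len(out_leaves) == 2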

@@ -111,12 +147,13 @@ def to_scaled_array(val):
        assert all([isinstance(v, ScaledArray) for v in invals])
        # TODO: handle `stop_scale` case? integer/boolean dtypes?

        scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, None))
        # Primitive is supported by `autoscale`?
        if eqn.primitive not in _scaled_ops_registry:
        if scaled_prim_fn is None:
            raise NotImplementedError(
                f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
            )
        outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
        outvals = scaled_prim_fn(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        safe_map(write, eqn.outvars, outvals)
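The promotion rule introduced by this commit can be summarised as a small predicate; a standalone rendition of the intended semantics (an assumption drawn from the docstrings above, not the exact library code):

from typing import Any, Sequence
from jax_scaled_arithmetics.core import ScaledArray, ScaledPrimitiveType

def use_scaled_translation(scaled_type: ScaledPrimitiveType, invals: Sequence[Any]) -> bool:
    # ALWAYS_SCALE primitives take the scaled path unconditionally (e.g. `set_scaling`);
    # FORWARD primitives take it only when at least one input is already scaled.
    if scaled_type == ScaledPrimitiveType.ALWAYS_SCALE:
        return True
    return any(isinstance(v, ScaledArray) for v in invals)
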
1 change: 0 additions & 1 deletion tests/core/test_interpreter.py
@@ -27,7 +27,6 @@ def func(x):
    # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
    assert jaxpr.invars[0].aval.shape == scaled_input.shape
    assert jaxpr.invars[1].aval.shape == ()

    assert jaxpr.outvars[0].aval.shape == scaled_input.shape
    assert jaxpr.outvars[1].aval.shape == ()

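The two-Vars-per-ScaledArray invariant checked here follows from a `ScaledArray` flattening into `(data, scale)` leaves; a minimal sketch of that assumption (reusing the hypothetical `scaled_array(data, scale)` constructor from above):

import jax
import numpy as np
from jax_scaled_arithmetics.core import scaled_array

x = scaled_array(np.ones((4,), np.float32), np.float32(0.5))
leaves = jax.tree_util.tree_leaves(x)
assert len(leaves) == 2  # data with shape (4,), scale with shape ().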
