diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index 31e9aec..4d66c5b 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array  # noqa: F401
-from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op  # noqa: F401
+from .interpreters import ScaledPrimitiveType, autoscale, register_scaled_lax_op, register_scaled_op  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 6e46c13..a1857ca 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -1,5 +1,5 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-
+from enum import IntEnum
 from functools import wraps
 from typing import Any, Dict
 
@@ -13,6 +13,21 @@
 _scaled_ops_registry: Dict[core.Primitive, Any] = {}
 
 
+class ScaledPrimitiveType(IntEnum):
+    """Scaled (JAX) primitive type.
+
+    This enum describes the behaviour when `autoscale` is
+    tracing the graph.
+
+    FORWARD: Forward scaling => the scaled translation is only used
+        when at least one input is a scaled array. Default behaviour.
+    ALWAYS_SCALE: Always use the scaled translation.
+    """
+
+    FORWARD = 1
+    ALWAYS_SCALE = 2
+
+
 def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
     """Get the ScaledArray corresponding to a Numpy constant.
 
@@ -24,19 +39,22 @@ def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
     return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
 
 
-def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
+def register_scaled_op(
+    prim: core.Primitive, scaled_func: Any, scaled_type: ScaledPrimitiveType = ScaledPrimitiveType.FORWARD
+) -> None:
     """Register the scaled translation of JAX primitive.
 
     Raises an error if a scaled translation is already existing for this primitive.
 
     Args:
         prim: JAX primitive.
-        scaled_fund: Scaled translation of the primitive. With the same interface.
+        scaled_func: Scaled translation of the primitive. With the same interface.
+        scaled_type: Scaled primitive type => behaviour when `autoscale` is tracing.
     """
     assert isinstance(prim, core.Primitive)
     if prim in _scaled_ops_registry:
         raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
-    _scaled_ops_registry[prim] = scaled_func
+    _scaled_ops_registry[prim] = (scaled_func, scaled_type)
 
 
 def _get_lax_prim(scaled_func: Any) -> core.Primitive:
@@ -60,25 +78,43 @@ def register_scaled_lax_op(scaled_func):
     Example: `scaled_mul` is matched to `jax.lax.mul_p`
     """
     lax_prim = _get_lax_prim(scaled_func)
-    register_scaled_op(lax_prim, scaled_func)
+    register_scaled_op(lax_prim, scaled_func, ScaledPrimitiveType.FORWARD)
     # Always return the function in the case of decorator use.
     return scaled_func
 
 
 def autoscale(fun):
+    """`autoscale` JAX graph transformation.
+
+    The `autoscale` graph transformation works in a forwarding mode:
+    scaled arrays are forwarded to scaled primitives, which generate scaled outputs.
+
+    If none of the inputs to a JAX primitive is scaled, the normal primitive is called
+    instead, generating a common JAX output array.
+
+    This behaviour is the standard one for `ScaledPrimitiveType.FORWARD` primitives.
+    An alternative behaviour is possible for `ScaledPrimitiveType.ALWAYS_SCALE` primitives, where the
+    scaled operation will always be called. A typical example is the `set_scaling` primitive.
+    """
+
     @wraps(fun)
     def wrapped(*args, **kwargs):
+        if len(kwargs) > 0:
+            raise NotImplementedError("`autoscale` JAX interpreter does not support keyword arguments at present.")
+
         aval_args = safe_map(lambda x: x.aval, args)
         # Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
         closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
         out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
+
+        inputs_scaled = args
         # Trace the graph & convert to scaled one.
-        outputs_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
+        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled)
         # Reconstruct the output Pytree, with scaled arrays.
         # NOTE: this step is also handling single vs multi outputs.
-        assert len(out_leaves) == len(outputs_flat)
-        output = jax.tree_util.tree_unflatten(out_pytree, outputs_flat)
-        return output
+        assert len(out_leaves) == len(outputs_scaled_flat)
+        output_scaled = jax.tree_util.tree_unflatten(out_pytree, outputs_scaled_flat)
+        return output_scaled
 
     return wrapped
 
@@ -111,12 +147,13 @@ def to_scaled_array(val):
         assert all([isinstance(v, ScaledArray) for v in invals])
 
         # TODO: handle `stop_scale` case? integer/boolean dtypes?
+        scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, None))
         # Primitive is supported by `autoscale`?
-        if eqn.primitive not in _scaled_ops_registry:
+        if scaled_prim_fn is None:
             raise NotImplementedError(
                 f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
             )
-        outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
+        outvals = scaled_prim_fn(*invals, **eqn.params)
         if not eqn.primitive.multiple_results:
             outvals = [outvals]
         safe_map(write, eqn.outvars, outvals)
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index bb21b2d..be5a023 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -27,7 +27,6 @@ def func(x):
 
     # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
     assert jaxpr.invars[0].aval.shape == scaled_input.shape
    assert jaxpr.invars[1].aval.shape == ()
-    assert jaxpr.outvars[0].aval.shape == scaled_input.shape
     assert jaxpr.outvars[1].aval.shape == ()
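
Below is a minimal usage sketch of the API introduced in this diff (not part of the change itself). The `scaled_exp` translation and `fn` are hypothetical names for illustration, the `scaled_array(data, scale, dtype=...)` helper signature is assumed from its export in `core/__init__.py`, and the sketch presumes `jax.lax.exp_p` has no scaled translation registered yet (otherwise `register_scaled_op` raises `KeyError`):

```python
import numpy as np
import jax

from jax_scaled_arithmetics.core import (
    ScaledArray,
    ScaledPrimitiveType,
    autoscale,
    register_scaled_op,
    scaled_array,
)


def scaled_exp(val: ScaledArray) -> ScaledArray:
    # Hypothetical scaled translation of `exp`: exp(scale * data) has no
    # simple scale decomposition, so compute on the materialized array and
    # re-emit the result with a unit scale.
    data = jax.lax.exp(val.data * val.scale)
    return ScaledArray(data=data, scale=np.array(1.0, dtype=val.scale.dtype))


# FORWARD (the default): the scaled translation is only picked when at least
# one input of the primitive is a ScaledArray.
register_scaled_op(jax.lax.exp_p, scaled_exp, ScaledPrimitiveType.FORWARD)


@autoscale
def fn(x):
    return jax.lax.exp(x)


# Scaled input => `scaled_exp` is called during tracing, and a ScaledArray
# output is propagated through the graph.
out = fn(scaled_array([1.0, 2.0], 0.5, dtype=np.float32))
print(out)
```

With `FORWARD` semantics, a call such as `fn(np.array([1.0, 2.0], np.float32))` traces through the normal `exp` primitive untouched; `ALWAYS_SCALE` would force the scaled translation even on plain inputs, which is the behaviour intended for primitives like `set_scaling`.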