Forwarding broadcasted scalar metadata in Scalify tracer.
The `scalify` interpreter/tracer now properly tracks which tensors are
just broadcasted scalars, which then helps refine the conversion rule to
ScaledArray for these tensors.

In practice, this means (finally!) proper full scale propagation in MNIST training,
resulting in stable training with dynamic rescaling.

TODO: we still need to understand why `scaled_mul` requires ScaledArray promotion to
get the MNIST training example running. This requirement has been lifted in `div/add/sub`
thanks to this PR.
balancap committed Jan 30, 2024
1 parent 8e14725 commit 48b4b98
Showing 6 changed files with 176 additions and 52 deletions.
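
As a quick illustration of the behaviour this commit enables, here is a minimal sketch mirroring the new
`test__autoscale_decorator__promotion_broadcasted_scalar_array` unit test further down (values are illustrative only):

import jax
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array

def fn(sa, b):
    # `b` is broadcasted to the full shape before the multiply...
    b = jax.lax.broadcast_in_dim(b, sa.shape, ())
    # ...but the scalify tracer now remembers it is just a broadcasted scalar.
    return sa * b

sa = scaled_array([0.5, 1.0], np.float32(4.0), dtype=np.float32)
b = jnp.array(4.0, dtype=np.float16)

sout = autoscale(fn)(sa, b)
assert isinstance(sout, ScaledArray)
# `b` is converted to a scaled scalar, so its value folds into the output scale: 4.0 * 4.0 == 16.
assert np.asarray(sout.scale) == np.float32(16)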
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -171,6 +171,7 @@ def as_scaled_array_base(
if isinstance(val, ScaledArray):
return val

assert scale is None or scale_dtype is None
# Simple case => we can ignore the scaling factor (i.e. implicitly 1).
is_static_one_scale: bool = scale is None or is_static_one_scalar(scale) # type:ignore
# Trivial cases: bool, int, float.
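
For reference, a short sketch of the scalar promotion performed by `as_scaled_array_base` (import path taken from
the file above; the new assert simply makes `scale` and `scale_dtype` mutually exclusive; values are only an example):

import numpy as np

from jax_scaled_arithmetics.core.datatype import as_scaled_array_base

# Scalar promotion: the power-of-two scale absorbs the magnitude of the value.
sv = as_scaled_array_base(np.float16(5.0), scale_dtype=np.float32)
# e.g. data == 1.25 (float16) with scale == 4.0 (float32), i.e. 1.25 * 4.0 == 5.0
print(sv.data, sv.scale)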
114 changes: 89 additions & 25 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -2,11 +2,11 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import partial, wraps
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

import jax
import numpy as np
from jax import core
from jax import core, lax
from jax._src.custom_derivatives import (
custom_jvp_call_jaxpr_p,
custom_jvp_call_p,
@@ -82,6 +82,15 @@ class ScaledPrimitiveType(IntEnum):
"""


_scalar_preserving_primitives: Set[core.Primitive] = set()
"""Scalar preserving JAX primitives
More specifically: if all inputs are (broadcasted) scalars, then the output(s)
are broadcasted scalars. Keeping track of broadcasted scalars is allowing
proper conversion to ScaledArrays (instead of assigning default scale 1).
"""


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
try:
prim_name = scaled_func.__name__.replace("scaled_", "") + "_p"
@@ -107,15 +116,6 @@ def _get_data(val: Any) -> Array:
return val


def promote_scalar_to_scaled_array(val: Any, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
"""Promote a scalar (Numpy, JAX, ...) to a Scaled Array.
Note: needs to work with any input type, including JAX tracer ones.
"""
# Use `as_scaled_array` promotion rules.
return as_scaled_array_base(val, scale_dtype=scale_dtype)


def register_scaled_op(
prim: core.Primitive, scaled_func: Any, scaled_type: ScaledPrimitiveType = ScaledPrimitiveType.FORWARD
) -> None:
@@ -160,15 +160,6 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv
return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None):
if isinstance(val, ScaledArray):
return val
elif np.ndim(val) == 0:
return promote_scalar_to_scaled_array(val, scale_dtype)
# No promotion rule => just return as such.
return val


@register_pytree_node_class
@dataclass(frozen=True, init=False)
class ScalifyTracerArray:
@@ -210,6 +201,10 @@ def tree_unflatten(cls, aux_data, children):
assert len(children) == 1
return cls(children[0], aux_data[0])

@property
def dtype(self) -> DTypeLike:
return self.array.dtype

@property
def size(self) -> int:
return self.array.size
@@ -223,12 +218,36 @@ def is_scaled_array(self) -> bool:
return isinstance(self.array, ScaledArray)

def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
if self.is_scaled_array:
"""(Tentatively) converting to a scaled array.
Supporting the following cases:
- scalar array;
- broadcasted scalar array;
Not supporting:
- bool/int dtypes;
- any other array;
TODO: support (constant) Numpy arrays.
"""
# Already scaled array, or not a floating point dtype.
if isinstance(self.array, ScaledArray) or not np.issubdtype(self.dtype, np.floating):
return self.array
# TODO: improve the logic for broadcasted scalar arrays!
return promote_to_scaled_array(self.array, scale_dtype)

if np.ndim(self.array) == 0:
# Single value => "easy case".
return as_scaled_array_base(self.array, scale_dtype=scale_dtype)
elif self.is_broadcasted_scalar:
# Broadcasted scalar => convert as a scalar.
scalar_val = self.array.ravel()[0]
scaled_scalar = as_scaled_array_base(scalar_val, scale_dtype=scale_dtype)
return as_scaled_array_base(self.array, scale=scaled_scalar.scale)

# No promotion rule found => just return as such.
return self.array

def to_array(self) -> Array:
"""Converting to a (normal) JAX/Numpy array."""
if not self.is_scaled_array:
return self.array
return self.array.to_array()
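
To make the conversion rule above concrete, a small sketch of `to_scaled_array` for the broadcasted-scalar and
non-floating cases (mirroring the new unit tests below; values are illustrative):

import numpy as np

from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray

# Broadcasted scalar: the first element is promoted, and its scale is reused for the full array.
data = np.array([5, 5], dtype=np.float16)
out = ScalifyTracerArray(data, is_broadcasted_scalar=True).to_scaled_array(scale_dtype=np.float32)
assert isinstance(out, ScaledArray) and out.scale.dtype == np.float32  # scale == 4

# Non-floating dtypes: no promotion rule, the input array is returned as-is.
int_arr = np.array(3)
assert ScalifyTracerArray(int_arr).to_scaled_array() is int_arr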
@@ -303,6 +322,7 @@ def write(var, val: ScalifyTracerArray):

# A few initial checks to make sure there is consistency.
assert len(jaxpr.invars) == len(args)
assert len(jaxpr.constvars) == len(consts)
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)

@@ -321,12 +341,17 @@ def write(var, val: ScalifyTracerArray):
any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer])
# Is there a scaled primitive associated?
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER))
# Are outputs broadcasted scalars?
are_outputs_broadcasted_scalars = (
all([v.is_broadcasted_scalar for v in invals_tracer]) and eqn.primitive in _scalar_preserving_primitives
)
scalify_array_init_fn = lambda v: ScalifyTracerArray(v, is_broadcasted_scalar=are_outputs_broadcasted_scalars)

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
invals = [v.to_array() for v in invals_tracer]
outvals = jaxpr_eqn_bind(eqn, invals)
outvals_tracer = list(map(ScalifyTracerArray, outvals))
outvals_tracer = list(map(scalify_array_init_fn, outvals))
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
Expand All @@ -337,7 +362,7 @@ def write(var, val: ScalifyTracerArray):
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
outvals_tracer = list(map(ScalifyTracerArray, outvals))
outvals_tracer = list(map(scalify_array_init_fn, outvals))

# Check consistency with normal JAX mode. Help catching dtype promotion errors.
# NOTE: ignoring when no outputs! (e.g. debug_callback).
@@ -451,3 +476,42 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any)

_scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation
_scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation


# Default collection of scalar preserving JAX primitives.
_scalar_preserving_primitives |= {
lax.abs_p,
lax.acos_p,
lax.acosh_p,
lax.add_p,
lax.asin_p,
lax.asinh_p,
lax.atan_p,
lax.atan2_p,
lax.atanh_p,
lax.bitcast_convert_type_p,
lax.broadcast_in_dim_p,
lax.cbrt_p,
lax.clamp_p,
lax.convert_element_type_p,
lax.integer_pow_p,
lax.min_p,
lax.max_p,
lax.mul_p,
lax.neg_p,
lax.reduce_prod_p,
lax.reduce_sum_p,
lax.reduce_max_p,
lax.reduce_min_p,
lax.reduce_precision_p,
lax.reshape_p,
lax.rem_p,
lax.slice_p,
lax.sin_p,
lax.sinh_p,
lax.sub_p,
lax.sqrt_p,
lax.tan_p,
lax.tanh_p,
lax.transpose_p,
}
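
The collection above drives the broadcasted-scalar forwarding in the tracer loop: outputs are flagged only when
every input is a broadcasted scalar and the equation primitive belongs to the set. A hedged sketch of checking or
extending it (the `lax.exp_p` addition is purely illustrative and not part of this PR):

from jax import lax

from jax_scaled_arithmetics.core import interpreters

# Membership mirrors the `are_outputs_broadcasted_scalars` check in the tracer loop.
assert lax.mul_p in interpreters._scalar_preserving_primitives
assert lax.dot_general_p not in interpreters._scalar_preserving_primitives

# Additional scalar-preserving primitives can be registered by extending the set.
interpreters._scalar_preserving_primitives.add(lax.exp_p)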
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -153,7 +153,7 @@ def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
# TODO: understand when promotion is really required?
lhs, rhs = as_scaled_array((lhs, rhs)) # type:ignore
# lhs, rhs = as_scaled_array((lhs, rhs)) # type:ignore
# TODO: investigate different rule?
return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)

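
Numerically, the `scaled_div` rule above divides the data and scale components separately; a small worked sketch
(illustrative values only):

import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, scaled_array

lhs = scaled_array([0.25, 0.5], np.float32(8.0), dtype=np.float32)  # represents [2.0, 4.0]
rhs = scaled_array([1.0, 1.0], np.float32(2.0), dtype=np.float32)   # represents [2.0, 2.0]
out = ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)
# out.data == [0.25, 0.5], out.scale == 4.0  =>  represents [1.0, 2.0]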
3 changes: 1 addition & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -10,7 +10,6 @@
from jax_scaled_arithmetics.core import (
DTypeLike,
ScaledArray,
as_scaled_array,
get_autoscale_config,
pow2_round,
register_scaled_op,
@@ -23,7 +22,7 @@
def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArray:
"""Scaled add/sub generic implementation."""
# TODO: understand when promotion is really required?
A, B = as_scaled_array((A, B)) # type:ignore
# A, B = as_scaled_array((A, B)) # type:ignore
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
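
With the promotion commented out here as well, scalar operands are expected to reach `scaled_add_sub` already
converted by the scalify tracer (via the broadcasted-scalar tracking above). A hedged end-to-end sketch of that path,
assuming a plain float constant gets promoted at the tracer level:

import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array

x = scaled_array([1.0, 2.0], np.float32(2.0), dtype=np.float32)
# The float32 constant is a scalar: the tracer converts it to a ScaledArray before
# `scaled_add_sub` runs, so no promotion is needed inside the op itself.
out = autoscale(lambda v: v + np.float32(1.0))(x)
assert isinstance(out, ScaledArray)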
76 changes: 52 additions & 24 deletions tests/core/test_interpreter.py
@@ -19,7 +19,7 @@
register_scaled_op,
scaled_array,
)
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray, promote_scalar_to_scaled_array
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray


class ScalifyTracerArrayTests(chex.TestCase):
@@ -73,6 +73,38 @@ def test__scalify_tracer_array__flatten__proper_pytree(self):
assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar
npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array))

@parameterized.parameters(
{"input": 3.0},
{"input": np.float32(3.0)},
{"input": np.array(3.0)},
{"input": jnp.array(3.0)},
)
def test__scalify_tracer_array__to_scaled_array__scalar_input(self, input):
scaled_val = ScalifyTracerArray(input).to_scaled_array()
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.data.dtype == scaled_val.scale.dtype
# NOTE: scale is a power-of-two.
npt.assert_almost_equal(np.asarray(scaled_val), input)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": np.int32(2)},
)
def test__scalify_tracer_array__to_scaled_array__not_promoted_input(self, input):
out = ScalifyTracerArray(input).to_scaled_array()
assert out is input

def test__scalify_tracer_array__to_scaled_array__broadcasted_scalar_input(self):
data = np.array([5, 5], dtype=np.float16)
scaled_out = ScalifyTracerArray(data, is_broadcasted_scalar=True).to_scaled_array(scale_dtype=np.float32)

assert isinstance(scaled_out, ScaledArray)
assert scaled_out.dtype == data.dtype
assert scaled_out.scale.dtype == np.float32
npt.assert_almost_equal(scaled_out.scale, 4)
npt.assert_array_equal(np.asarray(scaled_out), data)


class AutoScaleInterpreterTests(chex.TestCase):
def test__register_scaled_op__error_if_already_registered(self):
@@ -195,6 +227,25 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
assert scaled_out.dtype == exp_out.dtype
npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4)

@chex.variants(with_jit=True, without_jit=True)
def test__autoscale_decorator__promotion_broadcasted_scalar_array(self):
def fn(sa, b):
# Forcing broadcasting before the `lax.mul`
b = jax.lax.broadcast_in_dim(b, sa.shape, ())
return sa * b

sa = scaled_array([0.5, 1.0], np.float32(4.0), dtype=np.float32)
b = jnp.array(4.0, dtype=np.float16)

scaled_fn = self.variant(autoscale(fn))
sout = scaled_fn(sa, b)
expected_out = fn(np.asarray(sa), b)

assert isinstance(sout, ScaledArray)
# Proper output scale, with `b` treated as scaled scalar.
npt.assert_equal(np.asarray(sout.scale), np.float32(16))
npt.assert_array_equal(np.asarray(sout), expected_out)

@chex.variants(with_jit=True, without_jit=True)
def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self):
# JAX official `jvp` example.
Expand Down Expand Up @@ -264,29 +315,6 @@ def fn(x, y):
assert isinstance(sg, ScaledArray)
npt.assert_array_almost_equal(sg, g)

@parameterized.parameters(
{"input": 3.0},
{"input": np.float32(3.0)},
{"input": np.array(3.0)},
{"input": jnp.array(3.0)},
)
def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input):
scaled_val = promote_scalar_to_scaled_array(input)
assert isinstance(scaled_val, ScaledArray)
assert scaled_val.data.dtype == scaled_val.scale.dtype
# NOTE: scale is a power-of-two.
npt.assert_almost_equal(np.asarray(scaled_val), input)

@parameterized.parameters(
{"input": np.array(3)},
{"input": jnp.array(3)},
{"input": 3},
{"input": np.int32(2)},
)
def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input):
out = promote_scalar_to_scaled_array(input)
assert out is input

def test__autoscale_config__default_values(self):
cfg = get_autoscale_config()
assert isinstance(cfg, AutoScaleConfig)
32 changes: 32 additions & 0 deletions tests/lax/test_numpy_integration.py
@@ -0,0 +1,32 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array


class ScaledJaxNumpyFunctions(chex.TestCase):
def setUp(self):
super().setUp()
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

@chex.variants(with_jit=True, without_jit=True)
def test__numpy_mean__proper_gradient_scale_propagation(self):
def mean_fn(x):
# Taking the square to "force" a ScaledArray gradient.
# Numpy mean constant rescaling creates trouble on the backward pass!
return jax.grad(lambda v: jnp.mean(v * v))(x)

# size = 8 * 16
input_scaled = scaled_array(self.rs.rand(8, 16).astype(np.float32), np.float32(1))
output_grad_scaled = self.variant(autoscale(mean_fn))(input_scaled)

assert isinstance(output_grad_scaled, ScaledArray)
# Proper scale propagation on the backward pass (rough interval)
assert np.std(output_grad_scaled.data) >= 0.25
assert np.std(output_grad_scaled.data) <= 1.0
# "small" scale.
assert output_grad_scaled.scale <= 0.01
