Getting MNIST training working properly with jax.nn.relu (#98)
This PR fixes backward propagation in `jax.nn.relu` by properly handling `lax.select` and `lax.full_like`
scale propagation. As a consequence, JAX SciPy `logsumexp` scale propagation now works properly.
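For context, `jax.nn.relu` defines its gradient via a custom JVP of the form `lax.select(x > 0, g, lax.full_like(g, 0))`, which is why scale propagation through those two primitives determines the gradient's scale. A minimal sketch (plain JAX, no scaled arithmetic) to inspect the lowering:

```python
import jax
import jax.numpy as jnp

# The backward pass of `jax.nn.relu` lowers to a `select_n` between the
# incoming gradient and a broadcasted zero produced by `full_like`.
grad_relu = jax.grad(lambda x: jnp.sum(jax.nn.relu(x)))
print(jax.make_jaxpr(grad_relu)(jnp.arange(-2.0, 2.0)))
```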
balancap authored Feb 2, 2024
1 parent d50d13e commit 638e9f9
Showing 7 changed files with 173 additions and 21 deletions.
5 changes: 3 additions & 2 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -53,19 +53,20 @@ def predict(params, inputs):
for w, b in params[:-1]:
# Matmul + relu
outputs = jnp.dot(activations, w) + b
activations = jnp.maximum(outputs, 0)
activations = jax.nn.relu(outputs)

final_w, final_b = params[-1]
logits = jnp.dot(activations, final_w) + final_b
# Dynamic rescaling of the gradient, as logits gradient not properly scaled.
logits = jsa.ops.dynamic_rescale_l2_grad(logits)
# logits = jsa.ops.dynamic_rescale_l2_grad(logits)
logits = logits - logsumexp(logits, axis=1, keepdims=True)
return logits


def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
targets = jsa.lax.rebalance(targets, np.float32(1 / 16))
return -jnp.mean(jnp.sum(preds * targets, axis=1))


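A minimal end-to-end sketch of the pattern this example now relies on, assuming the `jsa.autoscale` decorator and `jsa.as_scaled_array` helper from this repository (shapes and the SGD step are hypothetical): with the relu backward fix, gradient scales propagate properly, so the explicit `dynamic_rescale_l2_grad` on the logits can be dropped.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp
import jax_scaled_arithmetics as jsa  # import alias assumed, as in this file


def loss(w, batch):
    inputs, targets = batch
    logits = jax.nn.relu(jnp.dot(inputs, w))
    logits = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(logits * targets, axis=1))


@jsa.autoscale
def update(w, batch):
    # Gradient scales now propagate properly through `jax.nn.relu`.
    grads = jax.grad(loss)(w, batch)
    return w - 0.01 * grads  # hypothetical SGD step


w = jsa.as_scaled_array(np.random.randn(784, 10).astype(np.float32))
batch = (np.random.rand(8, 784).astype(np.float32),
         np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=8)])
w = update(w, batch)
```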
3 changes: 3 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -43,6 +43,9 @@ class ScaledArray:
scale: GenericArray

def __post_init__(self):
# Always have a Numpy array as `data`.
if isinstance(self.data, np.number):
object.__setattr__(self, "data", np.array(self.data))
# TODO/FIXME: support number as data?
assert isinstance(self.data, (*ArrayTypes, np.ndarray))
assert isinstance(self.scale, (*ArrayTypes, np.ndarray, np.number))
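A short illustration of the `__post_init__` normalization added above (import path assumed):

```python
import numpy as np
from jax_scaled_arithmetics.core import ScaledArray  # import path assumed

# NumPy scalars passed as `data` are now wrapped into NumPy arrays, so
# downstream code can rely on `data` always being an ndarray (or JAX array).
sa = ScaledArray(np.float32(2.0), np.float32(0.5))
assert isinstance(sa.data, np.ndarray)
```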
112 changes: 99 additions & 13 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -16,7 +16,16 @@
from jax._src.util import safe_map
from jax.tree_util import register_pytree_node_class

from .datatype import Array, ArrayTypes, DTypeLike, ScaledArray, Shape, as_scaled_array_base, is_scaled_leaf
from .datatype import (
Array,
ArrayTypes,
DTypeLike,
ScaledArray,
Shape,
as_scaled_array_base,
is_scaled_leaf,
is_static_zero,
)
from .utils import Pow2RoundMode, python_scalar_as_numpy


@@ -82,14 +91,22 @@ class ScaledPrimitiveType(IntEnum):
"""


_scalar_preserving_primitives: Set[core.Primitive] = set()
"""Scalar preserving JAX primitives
_broadcasted_scalar_preserving_primitives: Set[core.Primitive] = set()
"""Scalar (broadcasted array) preserving JAX primitives
More specifically: if all inputs are (broadcasted) scalars, then the output(s)
are broadcasted scalars. Keeping track of broadcasted scalars allows
proper conversion to ScaledArrays (instead of assigning the default scale of 1).
"""

_broadcasted_zero_preserving_primitives: Set[core.Primitive] = set()
"""Zero (broadcasted array) preserving JAX primitives
More specifically: if all inputs are (broadcasted) zero scalars, then the output(s)
are broadcasted zero scalars. Keeping track of broadcasted zero scalars allows
proper conversion to ScaledArrays (instead of assigning the default scale of 1).
"""


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
try:
@@ -172,34 +189,52 @@ class ScalifyTracerArray:
Args:
array: Normal or scaled array.
is_broadcasted_scalar: Is the array a broadcasted scalar (metadata).
is_broadcasted_zero: Is the array a broadcasted zero scalar (metadata).
"""

array: Union[Array, ScaledArray] = None
# Metadata fields.
is_broadcasted_scalar: bool = False

def __init__(self, arr: Union[Array, ScaledArray], is_broadcasted_scalar: Optional[bool] = None) -> None:
is_broadcasted_zero: bool = False

def __init__(
self,
arr: Union[Array, ScaledArray],
is_broadcasted_scalar: Optional[bool] = None,
is_broadcasted_zero: Optional[bool] = None,
) -> None:
# Convert Python scalars, if necessary.
arr = python_scalar_as_numpy(arr)
assert isinstance(arr, (np.bool_, np.number, np.ndarray, ScaledArray, *ArrayTypes))
object.__setattr__(self, "array", arr)
# Optional is broadcasted scalar information.

# Is this a broadcasted zero scalar? Only check when the info is not provided.
if is_broadcasted_zero is None:
is_broadcasted_zero = bool(np.all(is_static_zero(self.array)))
object.__setattr__(self, "is_broadcasted_zero", is_broadcasted_zero)
# Optional broadcasted scalar information (always kept consistent with broadcasted zero!)
is_scalar = self.array.size == 1
is_broadcasted_scalar = is_scalar if is_broadcasted_scalar is None else is_broadcasted_scalar or is_scalar
is_broadcasted_scalar = is_broadcasted_scalar or is_broadcasted_zero
object.__setattr__(self, "is_broadcasted_scalar", is_broadcasted_scalar)

# Always make sure we have a zero scale to represent a broadcasted zero.
if is_broadcasted_zero and isinstance(self.array, ScaledArray):
object.__setattr__(self.array, "scale", np.array(0, self.array.scale.dtype))

def tree_flatten(self):
# See official JAX documentation on extending PyTrees.
# Note: using explicit tree flatten instead of chex for MyPy compatibility.
children = (self.array,)
aux_data = (self.is_broadcasted_scalar,)
aux_data = (self.is_broadcasted_scalar, self.is_broadcasted_zero)
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
# See official JAX documentation on extending PyTrees.
assert len(aux_data) == 1
assert len(aux_data) == 2
assert len(children) == 1
return cls(children[0], aux_data[0])
return cls(children[0], is_broadcasted_scalar=aux_data[0], is_broadcasted_zero=aux_data[1])

@property
def dtype(self) -> DTypeLike:
@@ -234,7 +269,12 @@ def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
if isinstance(self.array, ScaledArray) or not np.issubdtype(self.dtype, np.floating):
return self.array

if np.ndim(self.array) == 0:
if self.is_broadcasted_zero:
# Directly create the scaled array.
scale_dtype = scale_dtype or self.array.dtype
scale = np.array(0, dtype=scale_dtype)
return ScaledArray(data=self.array, scale=scale)
elif np.ndim(self.array) == 0:
# Single value => "easy case".
return as_scaled_array_base(self.array, scale_dtype=scale_dtype)
elif self.is_broadcasted_scalar:
@@ -341,11 +381,21 @@ def write(var, val: ScalifyTracerArray):
any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer])
# Is there a scaled primitive associated?
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER))

# Are outputs broadcasted scalars?
are_inputs_broadcasted_scalars = all([v.is_broadcasted_scalar for v in invals_tracer])
are_outputs_broadcasted_scalars = (
all([v.is_broadcasted_scalar for v in invals_tracer]) and eqn.primitive in _scalar_preserving_primitives
are_inputs_broadcasted_scalars and eqn.primitive in _broadcasted_scalar_preserving_primitives
)
# Are outputs broadcasted zeroes?
are_inputs_broadcasted_zeroes = all([v.is_broadcasted_zero for v in invals_tracer])
are_outputs_broadcasted_zeroes = (
are_inputs_broadcasted_zeroes and eqn.primitive in _broadcasted_zero_preserving_primitives
)
# Outputs scalify factory method.
scalify_array_init_fn = lambda v: ScalifyTracerArray(
v, is_broadcasted_scalar=are_outputs_broadcasted_scalars, is_broadcasted_zero=are_outputs_broadcasted_zeroes
)
scalify_array_init_fn = lambda v: ScalifyTracerArray(v, is_broadcasted_scalar=are_outputs_broadcasted_scalars)

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
@@ -479,7 +529,7 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any)


# Default collection of scalar preserving JAX primitives.
_scalar_preserving_primitives |= {
_broadcasted_scalar_preserving_primitives |= {
lax.abs_p,
lax.acos_p,
lax.acosh_p,
@@ -492,8 +542,12 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any)
lax.bitcast_convert_type_p,
lax.broadcast_in_dim_p,
lax.cbrt_p,
lax.ceil_p,
lax.clamp_p,
lax.convert_element_type_p,
lax.exp_p,
lax.expm1_p,
lax.floor_p,
lax.integer_pow_p,
lax.min_p,
lax.max_p,
@@ -515,3 +569,35 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any)
lax.tanh_p,
lax.transpose_p,
}

# Default collection of zero arrays preserving JAX primitives.
_broadcasted_zero_preserving_primitives |= {
lax.abs_p,
lax.add_p,
lax.broadcast_in_dim_p,
lax.cbrt_p,
lax.ceil_p,
lax.clamp_p,
lax.convert_element_type_p,
lax.exp_p,
lax.expm1_p,
lax.floor_p,
lax.min_p,
lax.max_p,
lax.mul_p,
lax.neg_p,
lax.reduce_prod_p,
lax.reduce_sum_p,
lax.reduce_max_p,
lax.reduce_min_p,
lax.reduce_precision_p,
lax.reshape_p,
lax.slice_p,
lax.sin_p,
lax.sinh_p,
lax.sub_p,
lax.sqrt_p,
lax.tan_p,
lax.tanh_p,
lax.transpose_p,
}
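A sketch of what the broadcasted-zero tracking buys, using the names from this diff (import path assumed): a zero array such as the output of `lax.full_like(g, 0)` now converts to a `ScaledArray` with scale 0 instead of the default scale of 1.

```python
import numpy as np
from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray  # path assumed

tracer = ScalifyTracerArray(np.zeros((4,), np.float32))
# Detected at construction time via `is_static_zero`.
assert tracer.is_broadcasted_zero and tracer.is_broadcasted_scalar
# Conversion assigns a zero scale rather than the default scale of 1.
assert tracer.to_scaled_array().scale == 0
```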
18 changes: 16 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -22,6 +22,12 @@
from .base_scaling_primitives import scaled_set_scaling


def _get_data(val: Any) -> Array:
if isinstance(val, ScaledArray):
return val.data
return val


def check_scalar_scales(*args: ScaledArray):
"""Check all ScaledArrays have scalar scaling."""
for val in args:
@@ -283,8 +289,16 @@ def scaled_log(val: ScaledArray) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_select_n(which: Array, *cases: ScaledArray) -> ScaledArray:
outscale_dtype = promote_types(*[get_scale_dtype(v) for v in cases])
outscale = np.array(1, dtype=outscale_dtype)
return scaled_op_default_translation(lax.select_n_p, [which, *cases], outscale=outscale)
# Get the max scale for renormalizing.
# TODO: use `get_data_scale` primitive.
scale_cases = [v.scale if isinstance(v, ScaledArray) else np.array(1, outscale_dtype) for v in cases]
scales_arr = jnp.array(scale_cases)
outscale = jnp.max(scales_arr)
# `data` components, renormalized.
data_cases = [_get_data(v) * (s / outscale).astype(v.dtype) for s, v in zip(scale_cases, cases)]
# Apply normal select!
outdata = lax.select_n(which, *data_cases)
return ScaledArray(outdata, outscale)


@core.register_scaled_lax_op
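A worked example of the max-scale renormalization in `scaled_select_n`, in plain NumPy (mirroring the updated unit test later in this diff). In particular, a zero-scale case (the relu backward's `full_like(g, 0)` branch) leaves the other input's scale intact, since `max(s, 0) == s`:

```python
import numpy as np

# Two cases with scales 2.0 and 4.0; the output scale is their max, 4.0.
lhs_data, lhs_scale = np.array([1.0, 2.0], np.float32), np.float32(2.0)
rhs_data, rhs_scale = np.array([3.0, 4.0], np.float32), np.float32(4.0)
outscale = max(lhs_scale, rhs_scale)
# Renormalize each `data` component onto the common output scale.
cases = [lhs_data * (lhs_scale / outscale), rhs_data * (rhs_scale / outscale)]
mask = np.array([True, False])
outdata = np.where(mask, cases[1], cases[0])
# The represented values match an unscaled select on the full values.
np.testing.assert_allclose(
    outdata * outscale,
    np.where(mask, rhs_data * rhs_scale, lhs_data * lhs_scale),
)
```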
31 changes: 30 additions & 1 deletion tests/core/test_interpreter.py
@@ -33,6 +33,7 @@ def test__scalify_tracer_array__init__from_python_value(self, arr):
assert tracer_arr.array == arr
assert not tracer_arr.is_scaled_array
assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1)
assert not tracer_arr.is_broadcasted_zero
assert tracer_arr.to_array() is tracer_arr.array

@parameterized.parameters(
@@ -45,32 +46,60 @@ def test__scalify_tracer_array__init__from_normal_array(self, arr):
assert tracer_arr.array is arr
assert not tracer_arr.is_scaled_array
assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1)
assert not tracer_arr.is_broadcasted_zero
assert tracer_arr.to_array() is tracer_arr.array
# Basic properties.
assert tracer_arr.shape == arr.shape
assert tracer_arr.size == arr.size

@parameterized.parameters(
{"arr": np.float32(2), "expected_is_zero": False},
{"arr": np.float32(0), "expected_is_zero": True},
{"arr": np.array([0, 0]), "expected_is_zero": True},
{"arr": np.array([0.0, 0.0]), "expected_is_zero": True},
{"arr": scaled_array([1, 2.0], 0.0, npapi=np), "expected_is_zero": True},
{"arr": scaled_array([0, 0.0], 1.0, npapi=np), "expected_is_zero": True},
{"arr": jnp.array([0, 0]), "expected_is_zero": False},
)
def test__scalify_tracer_array__init__zero_broadcasted_array(self, arr, expected_is_zero):
tracer_arr = ScalifyTracerArray(arr)
assert tracer_arr.is_broadcasted_zero == expected_is_zero
# Scaled array conversion => scale should be zero.
scaled_arr = tracer_arr.to_scaled_array()
if tracer_arr.is_broadcasted_zero and isinstance(scaled_arr, ScaledArray):
assert scaled_arr.scale == 0

@parameterized.parameters({"arr": scaled_array([1, 2], 3.0)})
def test__scalify_tracer_array__init__from_scaled_array(self, arr):
tracer_arr = ScalifyTracerArray(arr)
assert tracer_arr.array is arr
assert tracer_arr.is_scaled_array
assert tracer_arr.to_scaled_array() is tracer_arr.array
assert not tracer_arr.is_broadcasted_zero

def test__scalify_tracer_array__init__is_broadcasted_scalar_kwarg(self):
arr = scaled_array([1, 2], 3.0)
assert ScalifyTracerArray(arr, is_broadcasted_scalar=True).is_broadcasted_scalar
assert not ScalifyTracerArray(arr, is_broadcasted_scalar=False).is_broadcasted_scalar

def test__scalify_tracer_array__init__is_broadcasted_zero_kwarg(self):
arr = scaled_array([0, 1], 3.0)
# NOTE: explicitly passing the argument, not checking the data!
assert ScalifyTracerArray(arr, is_broadcasted_zero=True).is_broadcasted_scalar
assert ScalifyTracerArray(arr, is_broadcasted_zero=True).is_broadcasted_zero
assert not ScalifyTracerArray(arr, is_broadcasted_zero=False).is_broadcasted_scalar
assert not ScalifyTracerArray(arr, is_broadcasted_zero=False).is_broadcasted_zero

def test__scalify_tracer_array__flatten__proper_pytree(self):
arr = scaled_array([1, 2], 3.0)
tracer_arr_in = ScalifyTracerArray(arr, True)
tracer_arr_in = ScalifyTracerArray(arr, is_broadcasted_scalar=True, is_broadcasted_zero=True)
# Proper round trip!
flat_arrays, pytree = jax.tree_util.tree_flatten(tracer_arr_in)
tracer_arr_out = jax.tree_util.tree_unflatten(pytree, flat_arrays)

assert isinstance(tracer_arr_out, ScalifyTracerArray)
assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar
assert tracer_arr_out.is_broadcasted_zero == tracer_arr_in.is_broadcasted_zero
npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array))

@parameterized.parameters(
21 changes: 19 additions & 2 deletions tests/lax/test_scaled_ops_common.py
@@ -185,9 +185,26 @@ def test__scaled_boolean_binary_op__proper_result(self, bool_prim):
def test__scaled_select_n__proper_result(self):
mask = self.rs.rand(5) > 0.5
lhs = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
rhs = scaled_array(self.rs.rand(5), 3.0, dtype=np.float32)
rhs = scaled_array(self.rs.rand(5), 4.0, dtype=np.float32)
out = scaled_select_n(mask, lhs, rhs)
assert isinstance(out, ScaledArray)
assert out.dtype == np.float32
npt.assert_almost_equal(out.scale, 1) # FIXME!
# Max scale used.
npt.assert_almost_equal(out.scale, 4)
npt.assert_array_equal(out, np.where(mask, rhs, lhs))

@parameterized.parameters(
{"scale": 0.25},
{"scale": 8.0},
)
def test__scaled_select__relu_grad_example(self, scale):
@autoscale
def relu_grad(g):
return lax.select(g > 0, g, lax.full_like(g, 0))

# Gradient with some scale.
gin = scaled_array([1.0, 0.5], np.float32(scale), dtype=np.float32)
gout = relu_grad(gin)
# Same scale should be propagated to gradient output.
assert isinstance(gout, ScaledArray)
npt.assert_array_equal(gout.scale, gin.scale)
4 changes: 3 additions & 1 deletion tests/lax/test_scipy_integration.py
@@ -29,9 +29,11 @@ def fn(a):
def test__scipy_logsumexp__accurate_scaled_op(self, dtype):
from jax.scipy.special import logsumexp

input_scaled = scaled_array(self.rs.rand(10), 2, dtype=dtype)
input_scaled = scaled_array(self.rs.rand(10), 4.0, dtype=dtype)
# JAX `logsumexp` Jaxpr is a non-trivial graph!
out_scaled = autoscale(logsumexp)(input_scaled)
out_expected = logsumexp(np.asarray(input_scaled))
assert out_scaled.dtype == out_expected.dtype
# Proper accuracy + keep the same scale.
npt.assert_array_equal(out_scaled.scale, input_scaled.scale)
npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5)
