From ccc4355b36bd618aaa836e851b97474216dd070a Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Tue, 30 Jan 2024 09:22:35 +0000 Subject: [PATCH] Forwarding broadcasted scalar metadata in Scalify tracer. `scalify` interpreter/tracer is now properly tracking which tensors are just broadcasted scalars, helping then to refine the conversion rule to ScaledArray for these. In practice: it means (finally!) proper full scale propagation in MNIST training, resulting in stable training with dynamic rescale. TODO: we still need to understand why `scaled_mul` requires ScaledArray promotion to get the MNIST training example running. This requirement has been lifted in `div/add/sub` thanks to this PR. --- jax_scaled_arithmetics/core/datatype.py | 8 +- jax_scaled_arithmetics/core/interpreters.py | 114 ++++++++++++++---- .../lax/scaled_ops_common.py | 2 +- jax_scaled_arithmetics/lax/scaled_ops_l2.py | 3 +- tests/core/test_datatype.py | 13 +- tests/core/test_interpreter.py | 76 ++++++++---- tests/lax/test_numpy_integration.py | 32 +++++ 7 files changed, 191 insertions(+), 57 deletions(-) create mode 100644 tests/lax/test_numpy_integration.py diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index a3d3e71..4b31ac1 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -171,6 +171,7 @@ def as_scaled_array_base( if isinstance(val, ScaledArray): return val + assert scale is None or scale_dtype is None # Simple case => when can ignore the scaling factor (i.e. 1 implicitely). is_static_one_scale: bool = scale is None or is_static_one_scalar(scale) # type:ignore # Trivial cases: bool, int, float. @@ -189,12 +190,15 @@ def as_scaled_array_base( scale_dtype = scale_dtype or val.dtype scale = np.array(1, dtype=scale_dtype) if scale is None else scale - if isinstance(val, (np.ndarray, Array)): + if isinstance(val, (np.ndarray, *ArrayTypes)): if is_static_one_scale: return ScaledArray(val, scale) else: return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore - return scaled_array_base(val, scale) + + # TODO: fix bug when scale is not 1. + raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") + # return scaled_array_base(val, scale) def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray: diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 8de426d..378e0d7 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -2,11 +2,11 @@ from dataclasses import dataclass from enum import IntEnum from functools import partial, wraps -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import jax import numpy as np -from jax import core +from jax import core, lax from jax._src.custom_derivatives import ( custom_jvp_call_jaxpr_p, custom_jvp_call_p, @@ -82,6 +82,15 @@ class ScaledPrimitiveType(IntEnum): """ +_scalar_preserving_primitives: Set[core.Primitive] = set() +"""Scalar preserving JAX primitives + +More specifically: if all inputs are (broadcasted) scalars, then the output(s) +are broadcasted scalars. Keeping track of broadcasted scalars is allowing +proper conversion to ScaledArrays (instead of assigning default scale 1). 
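+
+A minimal illustrative example (assuming standard JAX broadcasting semantics):
+`lax.broadcast_in_dim(np.float32(2.0), (3, 4), ())` produces an array whose
+entries are all equal, and `lax.mul` of two such arrays is again a broadcasted
+scalar, so the tracer can later promote it to a ScaledArray with a meaningful
+(power-of-two) scale rather than the default scale of 1.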
+""" + + def _get_lax_prim(scaled_func: Any) -> core.Primitive: try: prim_name = scaled_func.__name__.replace("scaled_", "") + "_p" @@ -107,15 +116,6 @@ def _get_data(val: Any) -> Array: return val -def promote_scalar_to_scaled_array(val: Any, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray: - """Promote a scalar (Numpy, JAX, ...) to a Scaled Array. - - Note: needs to work with any input type, including JAX tracer ones. - """ - # Use `as_scaled_array` promotion rules. - return as_scaled_array_base(val, scale_dtype=scale_dtype) - - def register_scaled_op( prim: core.Primitive, scaled_func: Any, scaled_type: ScaledPrimitiveType = ScaledPrimitiveType.FORWARD ) -> None: @@ -160,15 +160,6 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER)) -def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None): - if isinstance(val, ScaledArray): - return val - elif np.ndim(val) == 0: - return promote_scalar_to_scaled_array(val, scale_dtype) - # No promotion rule => just return as such. - return val - - @register_pytree_node_class @dataclass(frozen=True, init=False) class ScalifyTracerArray: @@ -210,6 +201,10 @@ def tree_unflatten(cls, aux_data, children): assert len(children) == 1 return cls(children[0], aux_data[0]) + @property + def dtype(self) -> DTypeLike: + return self.array.dtype + @property def size(self) -> int: return self.array.size @@ -223,12 +218,36 @@ def is_scaled_array(self) -> bool: return isinstance(self.array, ScaledArray) def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray: - if self.is_scaled_array: + """(Tentatively) converting to a scaled array. + + Supporting the following cases: + - scalar array; + - broadcasted scalar array; + + Not supporting: + - bool/int dtypes; + - any other array; + + TODO: support (constant) Numpy arrays. + """ + # Already scaled array, or not a floating point dtype. + if isinstance(self.array, ScaledArray) or not np.issubdtype(self.dtype, np.floating): return self.array - # TODO: improve the logic for broadcasted scalar arrays! - return promote_to_scaled_array(self.array, scale_dtype) + + if np.ndim(self.array) == 0: + # Single value => "easy case". + return as_scaled_array_base(self.array, scale_dtype=scale_dtype) + elif self.is_broadcasted_scalar: + # Broadcasted scalar => convert as a scalar. + scalar_val = self.array.ravel()[0] + scaled_scalar = as_scaled_array_base(scalar_val, scale_dtype=scale_dtype) + return as_scaled_array_base(self.array, scale=scaled_scalar.scale) + + # No promotion rule found => just return as such. + return self.array def to_array(self) -> Array: + """Converting to a (normal) JAX/Numpy array.""" if not self.is_scaled_array: return self.array return self.array.to_array() @@ -303,6 +322,7 @@ def write(var, val: ScalifyTracerArray): # A few initial checks to make sure there is consistency. assert len(jaxpr.invars) == len(args) + assert len(jaxpr.constvars) == len(consts) safe_map(write, jaxpr.invars, args) safe_map(write, jaxpr.constvars, consts) @@ -321,12 +341,17 @@ def write(var, val: ScalifyTracerArray): any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer]) # Is there a scaled primitive associated? scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER)) + # Are outputs broadcasted scalars? 
+ are_outputs_broadcasted_scalars = ( + all([v.is_broadcasted_scalar for v in invals_tracer]) and eqn.primitive in _scalar_preserving_primitives + ) + scalify_array_init_fn = lambda v: ScalifyTracerArray(v, is_broadcasted_scalar=are_outputs_broadcasted_scalars) if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE: # Using normal JAX primitive: no scaled inputs, and not always scale rule. invals = [v.to_array() for v in invals_tracer] outvals = jaxpr_eqn_bind(eqn, invals) - outvals_tracer = list(map(ScalifyTracerArray, outvals)) + outvals_tracer = list(map(scalify_array_init_fn, outvals)) elif scaled_prim_fn is None: raise NotImplementedError( f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet." @@ -337,7 +362,7 @@ def write(var, val: ScalifyTracerArray): outvals = scaled_prim_fn(*scaled_invals, **eqn.params) if not eqn.primitive.multiple_results: outvals = [outvals] - outvals_tracer = list(map(ScalifyTracerArray, outvals)) + outvals_tracer = list(map(scalify_array_init_fn, outvals)) # Check consistency with normal JAX mode. Help catching dtype promotion errors. # NOTE: ignoring when no outputs! (e.g. debug_callback). @@ -451,3 +476,42 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) _scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation _scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation + + +# Default collection of scalar preserving JAX primitives. +_scalar_preserving_primitives |= { + lax.abs_p, + lax.acos_p, + lax.acosh_p, + lax.add_p, + lax.asin_p, + lax.asinh_p, + lax.atan_p, + lax.atan2_p, + lax.atanh_p, + lax.bitcast_convert_type_p, + lax.broadcast_in_dim_p, + lax.cbrt_p, + lax.clamp_p, + lax.convert_element_type_p, + lax.integer_pow_p, + lax.min_p, + lax.max_p, + lax.mul_p, + lax.neg_p, + lax.reduce_prod_p, + lax.reduce_sum_p, + lax.reduce_max_p, + lax.reduce_min_p, + lax.reduce_precision_p, + lax.reshape_p, + lax.rem_p, + lax.slice_p, + lax.sin_p, + lax.sinh_p, + lax.sub_p, + lax.sqrt_p, + lax.tan_p, + lax.tanh_p, + lax.transpose_p, +} diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py index d9e6345..f69b545 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_common.py +++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py @@ -153,7 +153,7 @@ def scaled_mul(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray: @core.register_scaled_lax_op def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray: # TODO: understand when promotion is really required? - lhs, rhs = as_scaled_array((lhs, rhs)) # type:ignore + # lhs, rhs = as_scaled_array((lhs, rhs)) # type:ignore # TODO: investigate different rule? return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale) diff --git a/jax_scaled_arithmetics/lax/scaled_ops_l2.py b/jax_scaled_arithmetics/lax/scaled_ops_l2.py index c476a16..aceee44 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_l2.py +++ b/jax_scaled_arithmetics/lax/scaled_ops_l2.py @@ -10,7 +10,6 @@ from jax_scaled_arithmetics.core import ( DTypeLike, ScaledArray, - as_scaled_array, get_autoscale_config, pow2_round, register_scaled_op, @@ -23,7 +22,7 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArray: """Scaled add/sub generic implementation.""" # TODO: understand when promotion is really required? 
- A, B = as_scaled_array((A, B)) # type:ignore + # A, B = as_scaled_array((A, B)) # type:ignore check_scalar_scales(A, B) A, B = promote_scale_types(A, B) assert np.issubdtype(A.scale.dtype, np.floating) diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index 40eddff..1d664f6 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -240,14 +240,21 @@ def test__as_scaled_array__unscaled_bool_int_output(self, data): output = as_scaled_array(data) assert output is data + @chex.variants(with_jit=True, without_jit=True) def test__as_scaled_array__complex_pytree(self): input = {"x": jnp.array([1, 2]), "y": jnp.array([1.0, 2]), "z": as_scaled_array(jnp.array([1.0, 2]))} - output = as_scaled_array(input) + output = self.variant(as_scaled_array)(input, scale=np.float32(2)) assert isinstance(output, dict) assert len(output) == 3 - assert output["x"] is input["x"] + + npt.assert_array_equal(output["x"], input["x"]) npt.assert_array_equal(output["y"], input["y"]) - assert output["z"] is input["z"] + npt.assert_array_equal(output["z"], input["z"]) + npt.assert_almost_equal(output["y"].scale, 2) + + if "without_jit" in self.variant.__name__: + assert output["x"] is input["x"] + assert output["z"] is input["z"] @chex.variants(with_jit=True, without_jit=True) @parameterized.parameters( diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 37681b9..609ddbe 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -19,7 +19,7 @@ register_scaled_op, scaled_array, ) -from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray, promote_scalar_to_scaled_array +from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray class ScalifyTracerArrayTests(chex.TestCase): @@ -73,6 +73,38 @@ def test__scalify_tracer_array__flatten__proper_pytree(self): assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array)) + @parameterized.parameters( + {"input": 3.0}, + {"input": np.float32(3.0)}, + {"input": np.array(3.0)}, + {"input": jnp.array(3.0)}, + ) + def test__scalify_tracer_array__to_scaled_array__scalar_input(self, input): + scaled_val = ScalifyTracerArray(input).to_scaled_array() + assert isinstance(scaled_val, ScaledArray) + assert scaled_val.data.dtype == scaled_val.scale.dtype + # NOTE: scale is a power-of-two. 
+ npt.assert_almost_equal(np.asarray(scaled_val), input) + + @parameterized.parameters( + {"input": np.array(3)}, + {"input": jnp.array(3)}, + {"input": np.int32(2)}, + ) + def test__scalify_tracer_array__to_scaled_array__not_promoted_input(self, input): + out = ScalifyTracerArray(input).to_scaled_array() + assert out is input + + def test__scalify_tracer_array__to_scaled_array__broadcasted_scalar_input(self): + data = np.array([5, 5], dtype=np.float16) + scaled_out = ScalifyTracerArray(data, is_broadcasted_scalar=True).to_scaled_array(scale_dtype=np.float32) + + assert isinstance(scaled_out, ScaledArray) + assert scaled_out.dtype == data.dtype + assert scaled_out.scale.dtype == np.float32 + npt.assert_almost_equal(scaled_out.scale, 4) + npt.assert_array_equal(np.asarray(scaled_out), data) + class AutoScaleInterpreterTests(chex.TestCase): def test__register_scaled_op__error_if_already_registered(self): @@ -195,6 +227,25 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, assert scaled_out.dtype == exp_out.dtype npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4) + @chex.variants(with_jit=True, without_jit=True) + def test__autoscale_decorator__promotion_broadcasted_scalar_array(self): + def fn(sa, b): + # Forcing broadcasting before the `lax.mul` + b = jax.lax.broadcast_in_dim(b, sa.shape, ()) + return sa * b + + sa = scaled_array([0.5, 1.0], np.float32(4.0), dtype=np.float32) + b = jnp.array(4.0, dtype=np.float16) + + scaled_fn = self.variant(autoscale(fn)) + sout = scaled_fn(sa, b) + expected_out = fn(np.asarray(sa), b) + + assert isinstance(sout, ScaledArray) + # Proper output scale, with `b` treated as scaled scalar. + npt.assert_equal(np.asarray(sout.scale), np.float32(16)) + npt.assert_array_equal(np.asarray(sout), expected_out) + @chex.variants(with_jit=True, without_jit=True) def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self): # JAX official `jvp` example. @@ -264,29 +315,6 @@ def fn(x, y): assert isinstance(sg, ScaledArray) npt.assert_array_almost_equal(sg, g) - @parameterized.parameters( - {"input": 3.0}, - {"input": np.float32(3.0)}, - {"input": np.array(3.0)}, - {"input": jnp.array(3.0)}, - ) - def test__promote_scalar_to_scaled_array__promoted_to_scaled_array(self, input): - scaled_val = promote_scalar_to_scaled_array(input) - assert isinstance(scaled_val, ScaledArray) - assert scaled_val.data.dtype == scaled_val.scale.dtype - # NOTE: scale is a power-of-two. - npt.assert_almost_equal(np.asarray(scaled_val), input) - - @parameterized.parameters( - {"input": np.array(3)}, - {"input": jnp.array(3)}, - {"input": 3}, - {"input": np.int32(2)}, - ) - def test__promote_scalar_to_scaled_array__not_promoted_to_scaled_array(self, input): - out = promote_scalar_to_scaled_array(input) - assert out is input - def test__autoscale_config__default_values(self): cfg = get_autoscale_config() assert isinstance(cfg, AutoScaleConfig) diff --git a/tests/lax/test_numpy_integration.py b/tests/lax/test_numpy_integration.py new file mode 100644 index 0000000..350d219 --- /dev/null +++ b/tests/lax/test_numpy_integration.py @@ -0,0 +1,32 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array + + +class ScaledJaxNumpyFunctions(chex.TestCase): + def setUp(self): + super().setUp() + # Use random state for reproducibility! 
+        self.rs = np.random.RandomState(42)
+
+    @chex.variants(with_jit=True, without_jit=True)
+    def test__numpy_mean__proper_gradient_scale_propagation(self):
+        def mean_fn(x):
+            # Taking the square to "force" a ScaledArray gradient.
+            # NumPy mean's constant rescaling causes trouble on the backward pass!
+            return jax.grad(lambda v: jnp.mean(v * v))(x)
+
+        # size = 8 * 16
+        input_scaled = scaled_array(self.rs.rand(8, 16).astype(np.float32), np.float32(1))
+        output_grad_scaled = self.variant(autoscale(mean_fn))(input_scaled)
+
+        assert isinstance(output_grad_scaled, ScaledArray)
+        # Proper scale propagation on the backward pass (rough interval check).
+        assert np.std(output_grad_scaled.data) >= 0.25
+        assert np.std(output_grad_scaled.data) <= 1.0
+        # "Small" scale expected.
+        assert output_grad_scaled.scale <= 0.01
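
Note (editor's sketch, not part of the patch): a hedged usage example of how the
new broadcasted-scalar tracking is expected to behave from a user's perspective.
It mirrors the `test__autoscale_decorator__promotion_broadcasted_scalar_array`
test added above and only relies on `autoscale`, `scaled_array` and `ScaledArray`
from this diff; any other detail is an assumption.

    import jax
    import jax.numpy as jnp
    import numpy as np

    from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array


    def fn(sa, b):
        # `b` gets broadcast to a full tensor before the multiply; the Scalify
        # tracer still remembers it is a broadcasted scalar.
        b = jax.lax.broadcast_in_dim(b, sa.shape, ())
        return sa * b


    sa = scaled_array([0.5, 1.0], np.float32(4.0), dtype=np.float32)
    b = jnp.array(4.0, dtype=np.float16)

    out = autoscale(fn)(sa, b)
    assert isinstance(out, ScaledArray)
    # `b` is treated as a scaled scalar, so it contributes to the output scale
    # (4 * 4 = 16) instead of being folded into the data with an implicit scale 1.
    print(out.scale)  # expected: 16.0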