From e3874d18404b61ce08e03034f507d2e0c29a7d5d Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 2 Feb 2024 12:02:57 +0000 Subject: [PATCH] wip --- jax_scaled_arithmetics/core/datatype.py | 3 + jax_scaled_arithmetics/core/interpreters.py | 112 ++++++++++++++++-- .../lax/scaled_ops_common.py | 18 ++- tests/core/test_interpreter.py | 31 ++++- tests/lax/test_scaled_ops_common.py | 7 +- tests/lax/test_scipy_integration.py | 4 +- 6 files changed, 156 insertions(+), 19 deletions(-) diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 4b31ac1..48655e4 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -43,6 +43,9 @@ class ScaledArray: scale: GenericArray def __post_init__(self): + # Always have a Numpy array as `data`. + if isinstance(self.data, np.number): + object.__setattr__(self, "data", np.array(self.data)) # TODO/FIXME: support number as data? assert isinstance(self.data, (*ArrayTypes, np.ndarray)) assert isinstance(self.scale, (*ArrayTypes, np.ndarray, np.number)) diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 378e0d7..c08f849 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -16,7 +16,16 @@ from jax._src.util import safe_map from jax.tree_util import register_pytree_node_class -from .datatype import Array, ArrayTypes, DTypeLike, ScaledArray, Shape, as_scaled_array_base, is_scaled_leaf +from .datatype import ( + Array, + ArrayTypes, + DTypeLike, + ScaledArray, + Shape, + as_scaled_array_base, + is_scaled_leaf, + is_static_zero, +) from .utils import Pow2RoundMode, python_scalar_as_numpy @@ -82,14 +91,22 @@ class ScaledPrimitiveType(IntEnum): """ -_scalar_preserving_primitives: Set[core.Primitive] = set() -"""Scalar preserving JAX primitives +_broadcasted_scalar_preserving_primitives: Set[core.Primitive] = set() +"""Scalar (broadcasted array) 
preserving JAX primitives

More specifically: if all inputs are (broadcasted) scalars, then the output(s)
are broadcasted scalars.

Keeping track of broadcasted scalars is allowing proper conversion to ScaledArrays
(instead of assigning default scale 1).
"""

+_broadcasted_zero_preserving_primitives: Set[core.Primitive] = set()
+"""Zero (broadcasted array) preserving JAX primitives
+
+More specifically: if all inputs are (broadcasted) zero scalars, then the output(s)
+are broadcasted zero scalars. Keeping track of broadcasted zero scalars is allowing
+proper conversion to ScaledArrays (instead of assigning default scale 1).
+"""
+
 
 def _get_lax_prim(scaled_func: Any) -> core.Primitive:
     try:
@@ -172,34 +189,52 @@ class ScalifyTracerArray:
     Args:
         array: Normal or scaled array.
         is_broadcasted_scalar: Is the array a broadcasted scalar (metadata).
+        is_broadcasted_zero: Is the array a broadcasted zero scalar (metadata).
     """
 
     array: Union[Array, ScaledArray] = None
+    # Metadata fields.
     is_broadcasted_scalar: bool = False
-
-    def __init__(self, arr: Union[Array, ScaledArray], is_broadcasted_scalar: Optional[bool] = None) -> None:
+    is_broadcasted_zero: bool = False
+
+    def __init__(
+        self,
+        arr: Union[Array, ScaledArray],
+        is_broadcasted_scalar: Optional[bool] = None,
+        is_broadcasted_zero: Optional[bool] = None,
+    ) -> None:
         # Convert Python scalars, if necessary.
         arr = python_scalar_as_numpy(arr)
         assert isinstance(arr, (np.bool_, np.number, np.ndarray, ScaledArray, *ArrayTypes))
         object.__setattr__(self, "array", arr)
-        # Optional is broadcasted scalar information.
+
+        # Is a zero broadcasted scalar? Only checking when info is not provided.
+        if is_broadcasted_zero is None:
+            is_broadcasted_zero = bool(np.all(is_static_zero(self.array)))
+        object.__setattr__(self, "is_broadcasted_zero", is_broadcasted_zero)
+        # Optional is broadcasted scalar information (always consistent with broadcasted zero!)
is_scalar = self.array.size == 1 is_broadcasted_scalar = is_scalar if is_broadcasted_scalar is None else is_broadcasted_scalar or is_scalar + is_broadcasted_scalar = is_broadcasted_scalar or is_broadcasted_zero object.__setattr__(self, "is_broadcasted_scalar", is_broadcasted_scalar) + # Always make sure we have zero scale to represent broadcasted zero. + if is_broadcasted_zero and isinstance(self.array, ScaledArray): + object.__setattr__(self.array, "scale", np.array(0, self.array.scale.dtype)) + def tree_flatten(self): # See official JAX documentation on extending PyTrees. # Note: using explicit tree flatten instead of chex for MyPy compatibility. children = (self.array,) - aux_data = (self.is_broadcasted_scalar,) + aux_data = (self.is_broadcasted_scalar, self.is_broadcasted_zero) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): # See official JAX documentation on extending PyTrees. - assert len(aux_data) == 1 + assert len(aux_data) == 2 assert len(children) == 1 - return cls(children[0], aux_data[0]) + return cls(children[0], is_broadcasted_scalar=aux_data[0], is_broadcasted_zero=aux_data[1]) @property def dtype(self) -> DTypeLike: @@ -234,7 +269,12 @@ def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArra if isinstance(self.array, ScaledArray) or not np.issubdtype(self.dtype, np.floating): return self.array - if np.ndim(self.array) == 0: + if self.is_broadcasted_zero: + # Directly create the scaled array. + scale_dtype = scale_dtype or self.array.dtype + scale = np.array(0, dtype=scale_dtype) + return ScaledArray(data=self.array, scale=scale) + elif np.ndim(self.array) == 0: # Single value => "easy case". return as_scaled_array_base(self.array, scale_dtype=scale_dtype) elif self.is_broadcasted_scalar: @@ -341,11 +381,21 @@ def write(var, val: ScalifyTracerArray): any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer]) # Is there a scaled primitive associated? 
scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER)) + # Are outputs broadcasted scalars? + are_inputs_broadcasted_scalars = all([v.is_broadcasted_scalar for v in invals_tracer]) are_outputs_broadcasted_scalars = ( - all([v.is_broadcasted_scalar for v in invals_tracer]) and eqn.primitive in _scalar_preserving_primitives + are_inputs_broadcasted_scalars and eqn.primitive in _broadcasted_scalar_preserving_primitives + ) + # Are outputs broadcasted zeroes? + are_inputs_broadcasted_zeroes = all([v.is_broadcasted_zero for v in invals_tracer]) + are_outputs_broadcasted_zeroes = ( + are_inputs_broadcasted_zeroes and eqn.primitive in _broadcasted_zero_preserving_primitives + ) + # Outputs scalify factory method. + scalify_array_init_fn = lambda v: ScalifyTracerArray( + v, is_broadcasted_scalar=are_outputs_broadcasted_scalars, is_broadcasted_zero=are_outputs_broadcasted_zeroes ) - scalify_array_init_fn = lambda v: ScalifyTracerArray(v, is_broadcasted_scalar=are_outputs_broadcasted_scalars) if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE: # Using normal JAX primitive: no scaled inputs, and not always scale rule. @@ -479,7 +529,7 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) # Default collection of scalar preserving JAX primitives. -_scalar_preserving_primitives |= { +_broadcasted_scalar_preserving_primitives |= { lax.abs_p, lax.acos_p, lax.acosh_p, @@ -492,8 +542,12 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) lax.bitcast_convert_type_p, lax.broadcast_in_dim_p, lax.cbrt_p, + lax.ceil_p, lax.clamp_p, lax.convert_element_type_p, + lax.exp_p, + lax.expm1_p, + lax.floor_p, lax.integer_pow_p, lax.min_p, lax.max_p, @@ -515,3 +569,35 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) lax.tanh_p, lax.transpose_p, } + +# Default collection of zero arrays preserving JAX primitives. 
+_broadcasted_zero_preserving_primitives |= { + lax.abs_p, + lax.add_p, + lax.broadcast_in_dim_p, + lax.cbrt_p, + lax.ceil_p, + lax.clamp_p, + lax.convert_element_type_p, + lax.exp_p, + lax.expm1_p, + lax.floor_p, + lax.min_p, + lax.max_p, + lax.mul_p, + lax.neg_p, + lax.reduce_prod_p, + lax.reduce_sum_p, + lax.reduce_max_p, + lax.reduce_min_p, + lax.reduce_precision_p, + lax.reshape_p, + lax.slice_p, + lax.sin_p, + lax.sinh_p, + lax.sub_p, + lax.sqrt_p, + lax.tan_p, + lax.tanh_p, + lax.transpose_p, +} diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py index f69b545..e09b137 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_common.py +++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py @@ -22,6 +22,12 @@ from .base_scaling_primitives import scaled_set_scaling +def _get_data(val: Any) -> Array: + if isinstance(val, ScaledArray): + return val.data + return val + + def check_scalar_scales(*args: ScaledArray): """Check all ScaledArrays have scalar scaling.""" for val in args: @@ -283,8 +289,16 @@ def scaled_log(val: ScaledArray) -> ScaledArray: @core.register_scaled_lax_op def scaled_select_n(which: Array, *cases: ScaledArray) -> ScaledArray: outscale_dtype = promote_types(*[get_scale_dtype(v) for v in cases]) - outscale = np.array(1, dtype=outscale_dtype) - return scaled_op_default_translation(lax.select_n_p, [which, *cases], outscale=outscale) + # Get the max scale for renormalizing. + # TODO: use `get_data_scale` primitive. + scale_cases = [v.scale if isinstance(v, ScaledArray) else np.array(1, outscale_dtype) for v in cases] + scales_arr = jnp.array(scale_cases) + outscale = jnp.max(scales_arr) + # `data` components, renormalized. + data_cases = [_get_data(v) * (s / outscale).astype(v.dtype) for s, v in zip(scale_cases, cases)] + # Apply normal select! 
+ outdata = lax.select_n(which, *data_cases) + return ScaledArray(outdata, outscale) @core.register_scaled_lax_op diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 609ddbe..c30b369 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -33,6 +33,7 @@ def test__scalify_tracer_array__init__from_python_value(self, arr): assert tracer_arr.array == arr assert not tracer_arr.is_scaled_array assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1) + assert not tracer_arr.is_broadcasted_zero assert tracer_arr.to_array() is tracer_arr.array @parameterized.parameters( @@ -45,32 +46,60 @@ def test__scalify_tracer_array__init__from_normal_array(self, arr): assert tracer_arr.array is arr assert not tracer_arr.is_scaled_array assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1) + assert not tracer_arr.is_broadcasted_zero assert tracer_arr.to_array() is tracer_arr.array # Basic properties. assert tracer_arr.shape == arr.shape assert tracer_arr.size == arr.size + @parameterized.parameters( + {"arr": np.float32(2), "expected_is_zero": False}, + {"arr": np.float32(0), "expected_is_zero": True}, + {"arr": np.array([0, 0]), "expected_is_zero": True}, + {"arr": np.array([0.0, 0.0]), "expected_is_zero": True}, + {"arr": scaled_array([1, 2.0], 0.0, npapi=np), "expected_is_zero": True}, + {"arr": scaled_array([0, 0.0], 1.0, npapi=np), "expected_is_zero": True}, + {"arr": jnp.array([0, 0]), "expected_is_zero": False}, + ) + def test__scalify_tracer_array__init__zero_broadcasted_array(self, arr, expected_is_zero): + tracer_arr = ScalifyTracerArray(arr) + assert tracer_arr.is_broadcasted_zero == expected_is_zero + # Scaled array conversion => scale should be zero. 
+ scaled_arr = tracer_arr.to_scaled_array() + if tracer_arr.is_broadcasted_zero and isinstance(scaled_arr, ScaledArray): + assert scaled_arr.scale == 0 + @parameterized.parameters({"arr": scaled_array([1, 2], 3.0)}) def test__scalify_tracer_array__init__from_scaled_array(self, arr): tracer_arr = ScalifyTracerArray(arr) assert tracer_arr.array is arr assert tracer_arr.is_scaled_array assert tracer_arr.to_scaled_array() is tracer_arr.array + assert not tracer_arr.is_broadcasted_zero def test__scalify_tracer_array__init__is_broadcasted_scalar_kwarg(self): arr = scaled_array([1, 2], 3.0) assert ScalifyTracerArray(arr, is_broadcasted_scalar=True).is_broadcasted_scalar assert not ScalifyTracerArray(arr, is_broadcasted_scalar=False).is_broadcasted_scalar + def test__scalify_tracer_array__init__is_broadcasted_zero_kwarg(self): + arr = scaled_array([0, 1], 3.0) + # NOTE: explicitly passing the argument, not checking the data! + assert ScalifyTracerArray(arr, is_broadcasted_zero=True).is_broadcasted_scalar + assert ScalifyTracerArray(arr, is_broadcasted_zero=True).is_broadcasted_zero + assert not ScalifyTracerArray(arr, is_broadcasted_zero=False).is_broadcasted_scalar + assert not ScalifyTracerArray(arr, is_broadcasted_zero=False).is_broadcasted_zero + def test__scalify_tracer_array__flatten__proper_pytree(self): arr = scaled_array([1, 2], 3.0) - tracer_arr_in = ScalifyTracerArray(arr, True) + tracer_arr_in = ScalifyTracerArray(arr, is_broadcasted_scalar=True, is_broadcasted_zero=True) # Proper round trip! 
flat_arrays, pytree = jax.tree_util.tree_flatten(tracer_arr_in) tracer_arr_out = jax.tree_util.tree_unflatten(pytree, flat_arrays) assert isinstance(tracer_arr_out, ScalifyTracerArray) assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar + assert tracer_arr_out.is_broadcasted_zero == tracer_arr_in.is_broadcasted_zero npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array)) @parameterized.parameters( diff --git a/tests/lax/test_scaled_ops_common.py b/tests/lax/test_scaled_ops_common.py index 1c5e6b3..70349cc 100644 --- a/tests/lax/test_scaled_ops_common.py +++ b/tests/lax/test_scaled_ops_common.py @@ -185,11 +185,12 @@ def test__scaled_boolean_binary_op__proper_result(self, bool_prim): def test__scaled_select_n__proper_result(self): mask = self.rs.rand(5) > 0.5 lhs = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32) - rhs = scaled_array(self.rs.rand(5), 3.0, dtype=np.float32) + rhs = scaled_array(self.rs.rand(5), 4.0, dtype=np.float32) out = scaled_select_n(mask, lhs, rhs) assert isinstance(out, ScaledArray) assert out.dtype == np.float32 - npt.assert_almost_equal(out.scale, 1) # FIXME! + # Max scale used. + npt.assert_almost_equal(out.scale, 4) npt.assert_array_equal(out, np.where(mask, rhs, lhs)) @parameterized.parameters( @@ -204,6 +205,8 @@ def relu_grad(g): # Gradient with some scale. gin = scaled_array([1.0, 0.5], np.float32(scale), dtype=np.float32) gout = relu_grad(gin) + # print(jax.make_jaxpr(relu_grad)(gin.data)) + # Same scale should be propagated to gradient output. 
assert isinstance(gout, ScaledArray) npt.assert_array_equal(gout.scale, gin.scale) diff --git a/tests/lax/test_scipy_integration.py b/tests/lax/test_scipy_integration.py index b0d97ae..c3dbc5a 100644 --- a/tests/lax/test_scipy_integration.py +++ b/tests/lax/test_scipy_integration.py @@ -29,9 +29,11 @@ def fn(a): def test__scipy_logsumexp__accurate_scaled_op(self, dtype): from jax.scipy.special import logsumexp - input_scaled = scaled_array(self.rs.rand(10), 2, dtype=dtype) + input_scaled = scaled_array(self.rs.rand(10), 4.0, dtype=dtype) # JAX `logsumexp` Jaxpr is a non-trivial graph! out_scaled = autoscale(logsumexp)(input_scaled) out_expected = logsumexp(np.asarray(input_scaled)) assert out_scaled.dtype == out_expected.dtype + # Proper accuracy + keep the same scale. + npt.assert_array_equal(out_scaled.scale, input_scaled.scale) npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5)