diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py
index 713ea60..c42c7ea 100644
--- a/jax_scaled_arithmetics/core/__init__.py
+++ b/jax_scaled_arithmetics/core/__init__.py
@@ -7,3 +7,4 @@
     register_scaled_lax_op,
     register_scaled_op,
 )
+from .typing import get_numpy_api  # noqa: F401
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index aea3e25..d4ab441 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -55,8 +55,13 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
 
     Note: needs to work with any input type, including JAX tracer ones.
     """
-    assert val.shape == ()
+    # int / float special cases
+    if isinstance(val, float):
+        return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val))
+    elif isinstance(val, int):
+        return ScaledArray(data=np.array(1, dtype=np.int32), scale=np.int32(val))
     # Just a Numpy constant for data => can be optimized out in XLA compiler.
+    assert val.shape == ()
     onedata = np.array(1, dtype=val.dtype)
     return ScaledArray(data=onedata, scale=val)
 
@@ -67,7 +72,7 @@ def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
 
     Only supporting Numpy scalars at the moment.
     """
     # TODO: generalized rules!
-    assert val.shape == ()
+    assert np.ndim(val) == 0
     assert np.issubdtype(val.dtype, np.floating)
     return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
@@ -164,7 +169,7 @@ def write(var, val):
     def promote_to_scaled_array(val):
         if isinstance(val, ScaledArray):
             return val
-        elif val.shape == ():
+        elif np.ndim(val) == 0:
             return promote_scalar_to_scaled_array(val)
         # No promotion rule => just return as such.
         return val
diff --git a/jax_scaled_arithmetics/core/typing.py b/jax_scaled_arithmetics/core/typing.py
index 9dc9fcb..bad0a6e 100644
--- a/jax_scaled_arithmetics/core/typing.py
+++ b/jax_scaled_arithmetics/core/typing.py
@@ -1 +1,18 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def get_numpy_api(val: Any) -> Any:
+    """Get the Numpy API corresponding to an array.
+
+    JAX or classic Numpy supported.
+    """
+    if isinstance(val, jax.Array):
+        return jnp
+    elif isinstance(val, (np.ndarray, np.number)):
+        return np
+    raise NotImplementedError(f"Unsupported input type '{type(val)}'. No matching Numpy API.")
diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py
index e048955..af2e4c5 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops.py
@@ -13,6 +13,28 @@
 from .base_scaling_primitives import scaled_set_scaling
 
 
+def check_scalar_scales(*args: ScaledArray):
+    """Check all ScaledArrays have scalar scaling."""
+    for val in args:
+        assert np.ndim(val.scale) == 0
+
+
+def promote_scale_types(*args: ScaledArray) -> Sequence[ScaledArray]:
+    """Promote scale datatypes to a common one.
+
+    Note: we are using JAX Numpy promotion, to avoid 64-bit types by default.
+    """
+    if len(args) == 1:
+        return args
+    # Find a common scale datatype.
+    scale_dtype = args[0].scale.dtype
+    for val in args[1:]:
+        scale_dtype = jnp.promote_types(scale_dtype, val.scale.dtype)
+
+    outputs = [ScaledArray(v.data, v.scale.astype(scale_dtype)) for v in args]
+    return outputs
+
+
 @core.register_scaled_lax_op
 def scaled_stop_gradient(val: ScaledArray) -> ScaledArray:
     # Stop gradients on both data and scale tensors.
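Not part of the patch: a minimal usage sketch of the two helpers above, assuming the package is importable and that `get_numpy_api` is re-exported from `jax_scaled_arithmetics.core` as per the `__init__.py` hunk. The promotion assertions rely only on documented NumPy/JAX dtype-promotion behaviour.

```python
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import get_numpy_api

# Dispatch on array type: JAX arrays map to `jax.numpy`, NumPy values to `numpy`.
assert get_numpy_api(jnp.ones((3,))) is jnp
assert get_numpy_api(np.float32(1.5)) is np  # `np.number` scalars are covered too.

# Why `promote_scale_types` uses JAX promotion: classic NumPy eagerly widens
# to 64-bit types, while JAX Numpy keeps 32-bit results by default.
assert np.promote_types(np.int32, np.float32) == np.float64
assert jnp.promote_types(np.int32, np.float32) == np.float32
```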
@@ -65,9 +87,9 @@ def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
 
 @core.register_scaled_lax_op
 def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
-    # Only supporting floating scale right now.
-    assert A.scale.dtype == B.scale.dtype
-    assert np.issubdtype(A.scale, np.floating)
+    check_scalar_scales(A, B)
+    A, B = promote_scale_types(A, B)
+    assert np.issubdtype(A.scale.dtype, np.floating)
     # TODO: what happens to `sqrt` for non-floating scale?
     output_scale = lax.sqrt(A.scale**2 + B.scale**2)
     # check correct type output if mismatch between data and scale precision
@@ -77,9 +99,10 @@
 
 @core.register_scaled_lax_op
 def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    check_scalar_scales(A, B)
+    A, B = promote_scale_types(A, B)
     # Only supporting floating scale right now.
-    assert A.scale.dtype == B.scale.dtype
-    assert np.issubdtype(A.scale, np.floating)
+    assert np.issubdtype(A.scale.dtype, np.floating)
     # TODO: what happens to `sqrt` for non-floating scale?
     output_scale = lax.sqrt(A.scale**2 + B.scale**2)
     # check correct type output if mismatch between data and scale precision
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index ab83d9d..c65bc6d 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -122,9 +122,12 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
     @parameterized.parameters(
         {"input": np.array(3)},
         {"input": jnp.array(3)},
+        {"input": 3},
+        {"input": 3.0},
     )
     def test__promote_scalar_to_scaled_array__proper_output(self, input):
         scaled_val = promote_scalar_to_scaled_array(input)
         assert isinstance(scaled_val, ScaledArray)
+        assert scaled_val.data.dtype == scaled_val.scale.dtype
         npt.assert_array_equal(scaled_val.data, 1)
         npt.assert_array_equal(scaled_val.scale, input)
diff --git a/tests/lax/test_scipy_integration.py b/tests/lax/test_scipy_integration.py
new file mode 100644
index 0000000..66d0c3d
--- /dev/null
+++ b/tests/lax/test_scipy_integration.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import chex
+import numpy as np
+import numpy.testing as npt
+
+from jax_scaled_arithmetics.core import autoscale, scaled_array
+
+
+class ScaledTranslationPrimitivesTests(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        # Use random state for reproducibility!
+        self.rs = np.random.RandomState(42)
+
+    def test__scipy_logsumexp__accurate_scaled_op(self):
+        from jax.scipy.special import logsumexp
+
+        input_scaled = scaled_array(self.rs.rand(10), 2, dtype=np.float32)
+        # JAX `logsumexp` Jaxpr is a non-trivial graph!
+        out_scaled = autoscale(logsumexp)(input_scaled)
+        out_expected = logsumexp(np.asarray(input_scaled))
+        npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5)
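Not part of the patch: a sketch of the promoted scaled addition end to end. It assumes `scaled_array(data, scale, dtype=...)` accepts a NumPy scalar scale (as in the tests above) and that a plain `+` traced by `autoscale` dispatches to the `scaled_add` rule; names mirror the existing tests.

```python
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import autoscale, scaled_array

# Two scaled arrays whose scales have different dtypes (float16 vs float32):
# `promote_scale_types` should reconcile them before `scaled_add` computes the
# output scale as `sqrt(A.scale**2 + B.scale**2)`.
a = scaled_array(np.array([1.0, 2.0]), np.float16(2.0), dtype=np.float32)
b = scaled_array(np.array([3.0, 4.0]), np.float32(4.0), dtype=np.float32)

out = autoscale(lambda x, y: x + y)(a, b)
# The scaled result should match the plain (unscaled) computation.
npt.assert_array_almost_equal(np.asarray(out), np.asarray(a) + np.asarray(b), decimal=5)
```

Before this change, the mismatched scale dtypes would have tripped the `A.scale.dtype == B.scale.dtype` assertion; promotion makes mixed-precision scales work transparently.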