Integration tests on high-level SciPy functions. #30

Merged · 1 commit · Nov 24, 2023
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
@@ -7,3 +7,4 @@
     register_scaled_lax_op,
     register_scaled_op,
 )
+from .typing import get_numpy_api # noqa: F401
11 changes: 8 additions & 3 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -55,8 +55,13 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:

     Note: needs to work with any input type, including JAX tracer ones.
     """
-    assert val.shape == ()
+    # int / float special cases
+    if isinstance(val, float):
+        return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val))
+    elif isinstance(val, int):
+        return ScaledArray(data=np.array(1, dtype=np.int32), scale=np.int32(val))
     # Just a Numpy constant for data => can be optimized out in XLA compiler.
+    assert val.shape == ()
     onedata = np.array(1, dtype=val.dtype)
     return ScaledArray(data=onedata, scale=val)
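In short: plain Python scalars are now encoded as a unit `data` tensor with the value folded into the scale. A quick sketch of the new behavior; the import path is assumed from the test module further down, not confirmed by this diff:

```python
import numpy as np

# Hypothetical import path, mirroring tests/core/test_interpreter.py.
from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array

sv = promote_scalar_to_scaled_array(2.5)   # Python float
assert sv.data == np.float32(1) and sv.scale == np.float32(2.5)

sv = promote_scalar_to_scaled_array(3)     # Python int
assert sv.data == np.int32(1) and sv.scale == np.int32(3)
```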

@@ -67,7 +72,7 @@ def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:

     Only supporting Numpy scalars at the moment.
     """
     # TODO: generalized rules!
-    assert val.shape == ()
+    assert np.ndim(val) == 0
     assert np.issubdtype(val.dtype, np.floating)
     return ScaledArray(data=np.array(1.0, dtype=val.dtype), scale=np.copy(val))
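The switch from `val.shape == ()` to `np.ndim(val) == 0` (here and in `promote_to_scaled_array` below) is what lets plain Python scalars through: they carry no `.shape` attribute, while `np.ndim` accepts them. A minimal illustration:

```python
import numpy as np

assert np.ndim(3.0) == 0             # plain Python float: no `.shape`, but ndim works
assert np.ndim(np.float32(3)) == 0   # Numpy scalar
assert np.ndim(np.array(3)) == 0     # 0-d array
# By contrast, (3.0).shape raises AttributeError.
```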

@@ -164,7 +169,7 @@ def write(var, val):
     def promote_to_scaled_array(val):
         if isinstance(val, ScaledArray):
             return val
-        elif val.shape == ():
+        elif np.ndim(val) == 0:
             return promote_scalar_to_scaled_array(val)
         # No promotion rule => just return as such.
         return val
17 changes: 17 additions & 0 deletions jax_scaled_arithmetics/core/typing.py
@@ -1 +1,18 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def get_numpy_api(val: Any) -> Any:
+    """Get the Numpy API corresponding to an array.
+
+    JAX or classic Numpy supported.
+    """
+    if isinstance(val, jax.Array):
+        return jnp
+    elif isinstance(val, (np.ndarray, np.number)):
+        return np
+    raise NotImplementedError(f"Unsupported input type '{type(val)}'. No matching Numpy API.")
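Usage is straightforward; a small sketch relying on the re-export added to `core/__init__.py` above:

```python
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import get_numpy_api

assert get_numpy_api(np.float32(1.0)) is np   # np.number => classic Numpy
assert get_numpy_api(jnp.ones(3)) is jnp      # jax.Array => jax.numpy
```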
33 changes: 28 additions & 5 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -13,6 +13,28 @@
 from .base_scaling_primitives import scaled_set_scaling


+def check_scalar_scales(*args: ScaledArray):
+    """Check all ScaledArrays have scalar scaling."""
+    for val in args:
+        assert np.ndim(val.scale) == 0
+
+
+def promote_scale_types(*args: ScaledArray) -> Sequence[ScaledArray]:
+    """Promote scale datatypes to a common one.
+
+    Note: we are using JAX Numpy promotion, to avoid 64bits types by default.
+    """
+    if len(args) == 1:
+        return args
+    # Find a common scale datatype.
+    scale_dtype = args[0].scale.dtype
+    for val in args[1:]:
+        scale_dtype = jnp.promote_types(scale_dtype, val.scale.dtype)
+
+    outputs = [ScaledArray(v.data, v.scale.astype(scale_dtype)) for v in args]
+    return outputs
+
+
 @core.register_scaled_lax_op
 def scaled_stop_gradient(val: ScaledArray) -> ScaledArray:
     # Stop gradients on both data and scale tensors.
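The docstring note about avoiding 64-bit types is the reason for `jnp.promote_types` over the classic Numpy rules: JAX's promotion lattice stays narrow where Numpy would widen. A small sketch of the difference:

```python
import jax.numpy as jnp
import numpy as np

# Classic Numpy widens int32 + float32 all the way to float64...
assert np.promote_types(np.int32, np.float32) == np.float64
# ...while JAX's lattice resolves the same pair to float32, keeping
# promoted scales in 32 bits by default.
assert jnp.promote_types(np.int32, np.float32) == np.float32
```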
@@ -65,9 +87,9 @@ def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:

 @core.register_scaled_lax_op
 def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
-    # Only supporting floating scale right now.
-    assert A.scale.dtype == B.scale.dtype
-    assert np.issubdtype(A.scale, np.floating)
+    check_scalar_scales(A, B)
+    A, B = promote_scale_types(A, B)
+    assert np.issubdtype(A.scale.dtype, np.floating)
     # TODO: what happens to `sqrt` for non-floating scale?
     output_scale = lax.sqrt(A.scale**2 + B.scale**2)
     # check correct type output if mismatch between data and scale precision
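The `sqrt(A.scale**2 + B.scale**2)` output scale combines the operand scales as a root sum of squares, as if they were independent standard deviations; the rescaling of the `data` tensors against this new scale happens in the lines elided from this hunk. A worked instance:

```python
import numpy as np

a_scale, b_scale = np.float32(3), np.float32(4)
output_scale = np.sqrt(a_scale**2 + b_scale**2)
# 3-4-5 triangle: the combined scale is 5, so each operand's data gets
# rescaled by scale / output_scale and the sum stays O(1).
assert output_scale == np.float32(5)
```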
@@ -77,9 +99,10 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:

 @core.register_scaled_lax_op
 def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    check_scalar_scales(A, B)
+    A, B = promote_scale_types(A, B)
     # Only supporting floating scale right now.
-    assert A.scale.dtype == B.scale.dtype
-    assert np.issubdtype(A.scale, np.floating)
+    assert np.issubdtype(A.scale.dtype, np.floating)
     # TODO: what happens to `sqrt` for non-floating scale?
     output_scale = lax.sqrt(A.scale**2 + B.scale**2)
     # check correct type output if mismatch between data and scale precision
3 changes: 3 additions & 0 deletions tests/core/test_interpreter.py
@@ -122,9 +122,12 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
     @parameterized.parameters(
         {"input": np.array(3)},
         {"input": jnp.array(3)},
+        {"input": 3},
+        {"input": 3.0},
     )
     def test__promote_scalar_to_scaled_array__proper_output(self, input):
         scaled_val = promote_scalar_to_scaled_array(input)
         assert isinstance(scaled_val, ScaledArray)
+        assert scaled_val.data.dtype == scaled_val.scale.dtype
         npt.assert_array_equal(scaled_val.data, 1)
         npt.assert_array_equal(scaled_val.scale, input)
22 changes: 22 additions & 0 deletions tests/lax/test_scipy_integration.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import chex
+import numpy as np
+import numpy.testing as npt
+
+from jax_scaled_arithmetics.core import autoscale, scaled_array
+
+
+class ScaledTranslationPrimitivesTests(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        # Use random state for reproducibility!
+        self.rs = np.random.RandomState(42)
+
+    def test__scipy_logsumexp__accurate_scaled_op(self):
+        from jax.scipy.special import logsumexp
+
+        input_scaled = scaled_array(self.rs.rand(10), 2, dtype=np.float32)
+        # JAX `logsumexp` Jaxpr is a non-trivial graph!
+        out_scaled = autoscale(logsumexp)(input_scaled)
+        out_expected = logsumexp(np.asarray(input_scaled))
+        npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5)
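Outside the test harness the same pattern applies end to end; a minimal sketch, assuming `scaled_array(data, scale, dtype=...)` and the `np.asarray` conversion behave as the test above relies on:

```python
import numpy as np
from jax.scipy.special import logsumexp

from jax_scaled_arithmetics.core import autoscale, scaled_array

x = scaled_array(np.random.rand(10), 2, dtype=np.float32)
out = autoscale(logsumexp)(x)  # traces the full logsumexp Jaxpr on ScaledArrays
print(np.asarray(out))         # back to a plain Numpy array
```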