Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX 0.3.x compatibility. Useful for running experiments on IPUs. #39

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ jobs:
run: |
pip3 install -e ./
pip3 install -r ./test-requirements.txt
# Run repository unit tests.
- name: Run unit tests
# Run repository unit tests on latest JAX
- name: Run unit tests JAX latest
run: |
pytest --tb=short -v --log-cli-level=INFO ./
- name: JAX 0.3.16 installation
run: |
pip3 install chex==0.1.6 jax==0.3.16 jaxlib==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Run repository unit tests on JAX 0.3
- name: Run unit tests JAX 0.3.16
run: |
pytest --tb=short -v --log-cli-level=INFO ./
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
register_scaled_lax_op,
register_scaled_op,
)
from .typing import get_numpy_api # noqa: F401
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
14 changes: 8 additions & 6 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from jax.tree_util import register_pytree_node_class
from numpy.typing import ArrayLike, DTypeLike, NDArray

GenericArray = Union[jax.Array, np.ndarray]
from .typing import Array, ArrayTypes

GenericArray = Union[Array, np.ndarray]


@register_pytree_node_class
Expand Down Expand Up @@ -40,8 +42,8 @@ class ScaledArray:
scale: GenericArray

def __post_init__(self):
assert isinstance(self.data, (jax.Array, np.ndarray))
assert isinstance(self.scale, (jax.Array, np.ndarray, np.number))
assert isinstance(self.data, (*ArrayTypes, np.ndarray))
assert isinstance(self.scale, (*ArrayTypes, np.ndarray, np.number))
# Only supporting scale scalar for now.
assert self.scale.shape == ()

Expand Down Expand Up @@ -94,7 +96,7 @@ def is_scaled_leaf(val: Any) -> bool:
to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
"""
# TODO: check Numpy scalars as well?
return np.isscalar(val) or isinstance(val, (jax.Array, np.ndarray, ScaledArray))
return np.isscalar(val) or isinstance(val, (Array, np.ndarray, ScaledArray))


def scaled_array_base(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray:
Expand Down Expand Up @@ -123,7 +125,7 @@ def as_scaled_array_base(val: Any, scale: Optional[ArrayLike] = None) -> ScaledA
scale = np.array(1, dtype=val.dtype) if scale is None else scale
if isinstance(val, ScaledArray):
return val
elif isinstance(val, (np.ndarray, jax.Array)):
elif isinstance(val, (np.ndarray, Array)):
return ScaledArray(val, scale)
return scaled_array_base(val, scale)

Expand All @@ -146,7 +148,7 @@ def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array, base function."""
if isinstance(val, ScaledArray):
return val.to_array(dtype=dtype)
elif isinstance(val, (jax.Array, np.ndarray)):
elif isinstance(val, (Array, np.ndarray)):
if dtype is None:
return val
return val.astype(dtype=dtype)
Expand Down
48 changes: 44 additions & 4 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import jax
import numpy as np
from jax import core
from jax._src.custom_derivatives import custom_jvp_call_p, custom_vjp_call_p
from jax._src.pjit import pjit_p
from jax._src.custom_derivatives import custom_jvp_call_jaxpr_p, custom_jvp_call_p, custom_vjp_call_p
from jax._src.util import safe_map

from .datatype import NDArray, ScaledArray, is_scaled_leaf
Expand Down Expand Up @@ -233,21 +232,62 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale
return outputs_scaled_flat


register_scaled_op(pjit_p, scaled_pjit_translation)
# `pjit_p` is imported defensively: the module is presumably missing on older JAX
# releases (e.g. JAX 0.3), in which case the `pjit` scaled translation is not registered.
try:
    from jax._src.pjit import pjit_p
except ImportError:
    # `ModuleNotFoundError` is a subclass of `ImportError`: catching the latter is enough.
    pass
else:
    # Kept outside the `try` body, so that an `ImportError` raised by the registration
    # itself is reported instead of being silently swallowed.
    register_scaled_op(pjit_p, scaled_pjit_translation)


def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
    """Scaled translation of the `xla_call` primitive, kept for JAX 0.3 compatibility.

    `autoscale_jaxpr` is re-run on the sub-jaxpr found in the primitive parameters, and the
    resulting function is wrapped with `jax.jit`, forwarding the `inline` and `keep_unused` options.

    Args:
        *args: Scaled inputs, passed as-is to the jitted scaled sub-function.
        **kwargs: Primitive parameters. `call_jaxpr`, `name`, `inline` and `keep_unused`
            are required; any other parameter is currently ignored.

    Returns:
        Outputs of the jitted scaled sub-function.

    Raises:
        KeyError: If one of the required parameters is missing.
        AssertionError: If the sub-jaxpr has constant variables.
    """
    # All required parameters are read upfront, in a fixed order.
    call_jaxpr, fn_name, inline, keep_unused = (
        kwargs[key] for key in ("call_jaxpr", "name", "inline", "keep_unused")
    )
    # TODO: properly adapt + pass the remaining options
    # (`donated_invars`, `in_shardings`, `out_shardings`).

    # The sub-jaxpr is scaled with an empty list of constants: no constvars supported.
    assert len(call_jaxpr.constvars) == 0
    # Build the scaled sub-function, keeping the original name and `jax.jit` options.
    scaled_fn = partial(autoscale_jaxpr, call_jaxpr, [])
    scaled_fn.__name__ = fn_name  # type:ignore
    jitted_fn = jax.jit(scaled_fn, inline=inline, keep_unused=keep_unused)
    return jitted_fn(*args)


# `xla_call_p` is imported defensively: it is presumably only available on older JAX
# releases (e.g. JAX 0.3). When missing, the `xla_call` scaled translation is not registered.
try:
    from jax.interpreters.xla import xla_call_p
except ImportError:
    # `ModuleNotFoundError` is a subclass of `ImportError`: catching the latter is enough.
    pass
else:
    # Kept outside the `try` body, so that an `ImportError` raised by the registration
    # itself is reported instead of being silently swallowed.
    register_scaled_op(xla_call_p, scaled_xla_call_translation)


def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
    """Scaled translation of the `custom_jvp_call` primitives.

    The scaled call is forwarded to the closed sub-jaxpr carried by the primitive parameters,
    by running `autoscale_jaxpr` on it. The custom `jvp` rule is not re-attached to the
    scaled function (see FIXME below).

    Args:
        *args: Scaled inputs, passed as-is to the scaled sub-jaxpr.
        **params: Primitive parameters. The closed sub-jaxpr is read from `call_jaxpr`
            (recent JAX) when present, and from `fun_jaxpr` (JAX 0.3) otherwise.

    Returns:
        Outputs of the scaled sub-jaxpr.

    Raises:
        KeyError: If neither `call_jaxpr` nor `fun_jaxpr` is present in the parameters.
        AssertionError: If the `num_consts` parameter is present and non-zero.
    """
    # NOTE: `custom_jvp_call_p.get_bind_params(params)` would give access to `[fun, jvp]`.
    # The key is selected from the parameters actually received, rather than from the JAX
    # minor version number: the latter ignores the major version, and assumes which
    # primitive (and parameter layout) is in use.
    key_jaxpr = "call_jaxpr" if "call_jaxpr" in params else "fun_jaxpr"
    call_closed_jaxpr = params[key_jaxpr]
    # JAX 0.3 compatibility.
    assert params.get("num_consts", 0) == 0
    # FIXME: re-call the custom_jvp decorator/bind.
    call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals)
    return call_subfunc(*args)


register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation)
register_scaled_op(custom_jvp_call_jaxpr_p, scaled_custom_jvp_call_translation)


def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
Expand Down
10 changes: 10 additions & 0 deletions jax_scaled_arithmetics/core/typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any

# import chex
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np

# Type aliasing. To be compatible with JAX 0.3 as well.
# `Array` is the single array type (annotations, simple `isinstance` checks), whereas
# `ArrayTypes` is a tuple of types, which on JAX 0.3 also includes the jaxpr tracer type.
# The (major, minor) pair is compared, not the minor number alone, such that a major
# version bump (e.g. 1.0) is not mistaken for JAX 0.3.
if jax.__version_info__[:2] > (0, 3):
    Array = jax.Array
    ArrayTypes = (jax.Array,)
else:
    Array = jaxlib.xla_extension.DeviceArray
    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)  # type:ignore


def get_numpy_api(val: Any) -> Any:
"""Get the Numpy API corresponding to an array.
Expand Down
13 changes: 6 additions & 7 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Optional, Sequence, Union

import jax
from jax import core
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir

from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, ScaledPrimitiveType, asarray, register_scaled_op
from jax_scaled_arithmetics.core import Array, DTypeLike, ScaledArray, ScaledPrimitiveType, asarray, register_scaled_op

set_scaling_p = core.Primitive("set_scaling_p")
"""`set_scaling` JAX primitive.
Expand All @@ -19,12 +18,12 @@
"""


def set_scaling(values: jax.Array, scale: jax.Array) -> jax.Array:
def set_scaling(values: Array, scale: Array) -> Array:
"""`set_scaling` primitive call method."""
return set_scaling_p.bind(values, scale)


def set_scaling_impl(values: jax.Array, scale: jax.Array) -> jax.Array:
def set_scaling_impl(values: Array, scale: Array) -> Array:
return values


Expand Down Expand Up @@ -73,12 +72,12 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray:
"""


def stop_scaling(values: jax.Array, dtype: Optional[DTypeLike] = None) -> jax.Array:
def stop_scaling(values: Array, dtype: Optional[DTypeLike] = None) -> Array:
"""`stop_scaling` primitive call method."""
return stop_scaling_p.bind(values, dtype=dtype)


def stop_scaling_impl(values: jax.Array, dtype: Optional[DTypeLike]) -> jax.Array:
def stop_scaling_impl(values: Array, dtype: Optional[DTypeLike]) -> Array:
if dtype is not None:
values = values.astype(dtype)
return values
Expand All @@ -100,7 +99,7 @@ def stop_scaling_mlir_lowering(
return (args[0],)


def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) -> jax.Array:
def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) -> Array:
"""Scaled `stop_scaling` implementation: returning tensor values (with optional cast)."""
assert isinstance(values, ScaledArray)
# TODO/FIXME: how to handle not scaled input?
Expand Down
22 changes: 11 additions & 11 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jax._src.ad_util import add_any_p

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape, as_scaled_array, register_scaled_op
from jax_scaled_arithmetics.core import Array, DTypeLike, ScaledArray, Shape, as_scaled_array, register_scaled_op

from .base_scaling_primitives import scaled_set_scaling

Expand Down Expand Up @@ -204,7 +204,7 @@ def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:


@core.register_scaled_lax_op
def scaled_is_finite(val: ScaledArray) -> jax.Array:
def scaled_is_finite(val: ScaledArray) -> Array:
assert isinstance(val, ScaledArray)
if np.issubdtype(val.scale.dtype, np.integer):
# Integer scale case => only check the data component.
Expand All @@ -213,7 +213,7 @@ def scaled_is_finite(val: ScaledArray) -> jax.Array:
return lax.and_p.bind(lax.is_finite(val.data), lax.is_finite(val.scale))


def scaled_boolean_binary_op(lhs: ScaledArray, rhs: ScaledArray, prim: jax.core.Primitive) -> jax.Array:
def scaled_boolean_binary_op(lhs: ScaledArray, rhs: ScaledArray, prim: jax.core.Primitive) -> Array:
"""Generic implementation of any boolean binary operation."""
assert isinstance(lhs, ScaledArray)
assert isinstance(rhs, ScaledArray)
Expand All @@ -223,40 +223,40 @@ def scaled_boolean_binary_op(lhs: ScaledArray, rhs: ScaledArray, prim: jax.core.


@core.register_scaled_lax_op
def scaled_eq(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
def scaled_eq(lhs: ScaledArray, rhs: ScaledArray) -> Array:
return scaled_boolean_binary_op(lhs, rhs, lax.eq_p)


@core.register_scaled_lax_op
def scaled_ne(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
def scaled_ne(lhs: ScaledArray, rhs: ScaledArray) -> Array:
return scaled_boolean_binary_op(lhs, rhs, lax.ne_p)


@core.register_scaled_lax_op
def scaled_gt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
def scaled_gt(lhs: ScaledArray, rhs: ScaledArray) -> Array:
return scaled_boolean_binary_op(lhs, rhs, lax.gt_p)


@core.register_scaled_lax_op
def scaled_ge(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
def scaled_ge(lhs: ScaledArray, rhs: ScaledArray) -> Array:
return scaled_boolean_binary_op(lhs, rhs, lax.ge_p)


@core.register_scaled_lax_op
def scaled_lt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
def scaled_lt(lhs: ScaledArray, rhs: ScaledArray) -> Array:
return scaled_boolean_binary_op(lhs, rhs, lax.lt_p)


@core.register_scaled_lax_op
def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> Array:
return scaled_boolean_binary_op(lhs, rhs, lax.le_p)


##################################################################
# Default scaled ops implementation #
##################################################################
def scaled_op_default_translation(
prim: jax.core.Primitive, args: Sequence[ScaledArray], outscale: Optional[jax.Array] = None
prim: jax.core.Primitive, args: Sequence[ScaledArray], outscale: Optional[Array] = None
) -> ScaledArray:
"""Scaled op default translation of a JAX primitive: unscaling inputs + calling normal primitive.

Expand Down Expand Up @@ -285,7 +285,7 @@ def scaled_log(val: ScaledArray) -> ScaledArray:


@core.register_scaled_lax_op
def scaled_select_n(which: jax.Array, *cases: ScaledArray) -> ScaledArray:
def scaled_select_n(which: Array, *cases: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.select_n_p, [which, *cases])


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
dependencies = [
"chex",
"chex >= 0.1.6",
"jax >= 0.3.16",
"jaxlib >= 0.3.15",
"numpy >= 1.22.4"
Expand Down
5 changes: 2 additions & 3 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax.core import ShapedArray

from jax_scaled_arithmetics.core import ScaledArray, as_scaled_array, asarray, is_scaled_leaf, scaled_array
from jax_scaled_arithmetics.core import Array, ScaledArray, as_scaled_array, asarray, is_scaled_leaf, scaled_array


class ScaledArrayDataclassTests(chex.TestCase):
Expand Down Expand Up @@ -135,6 +134,6 @@ def test__asarray__complex_pytree(self):
output = asarray(input)
assert isinstance(output, dict)
assert len(output) == 2
assert all([isinstance(v, jax.Array) for v in output.values()])
assert all([isinstance(v, Array) for v in output.values()])
npt.assert_array_almost_equal(output["x"], input["x"])
npt.assert_array_almost_equal(output["y"], input["y"])
7 changes: 4 additions & 3 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from absl.testing import parameterized

from jax_scaled_arithmetics.core import (
Array,
ScaledArray,
asarray,
autoscale,
Expand All @@ -32,7 +33,7 @@ def func(x):
data = np.array([1, 2], dtype=np.float32)
out = func(data)
# Proper behaviour!
assert isinstance(out, jax.Array)
assert isinstance(out, Array)
npt.assert_array_equal(out, [2, 4])
# Check jaxpr.
jaxpr = jax.make_jaxpr(func)(data).jaxpr
Expand Down Expand Up @@ -65,7 +66,7 @@ def myfunc(x):
# One main jit equation.
assert len(jaxpr.eqns) == 1
eqn = jaxpr.eqns[0]
assert eqn.primitive.name == "pjit"
assert eqn.primitive.name in ("pjit", "xla_call")
assert eqn.params["name"] == "myfunc"
# TODO: other parameters.
# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
Expand Down Expand Up @@ -125,7 +126,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
expected_output = self.variant(fn)(*raw_inputs)

# Do we re-construct properly the output type (i.e. handling Pytree properly)?
if not isinstance(expected_output, (np.ndarray, jax.Array)):
if not isinstance(expected_output, (np.ndarray, Array)):
assert type(scaled_output) is type(expected_output)

# Check each output in the flatten tree.
Expand Down
7 changes: 3 additions & 4 deletions tests/lax/test_base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.core import Array, ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.lax import set_scaling, stop_scaling


Expand Down Expand Up @@ -64,8 +63,8 @@ def fn(arr):
fn = self.variant(autoscale(fn))
arr = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
out0, out1 = fn(arr)
assert isinstance(out0, jax.Array)
assert isinstance(out1, jax.Array)
assert isinstance(out0, Array)
assert isinstance(out1, Array)
assert out0.dtype == arr.dtype
assert out1.dtype == np.float16
npt.assert_array_equal(out0, arr)
Expand Down
Loading