Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scaled translation rules for trivial LAX primitives. #14

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import ScaledArray, scaled_array # noqa: F401
from .datatype import DTypeLike, ScaledArray, Shape, scaled_array # noqa: F401
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
43 changes: 30 additions & 13 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,56 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
from typing import Dict
from typing import Any, Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray

_scaled_ops_registry = {}
_scaled_ops_registry: Dict[core.Primitive, Any] = {}


def register_scaled_op(lax_func, scaled_func):
_scaled_ops_registry[lax_func] = scaled_func
def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
    """Register the scaled translation of a JAX primitive.

    Args:
        prim: JAX primitive.
        scaled_func: Scaled translation of the primitive, with the same interface.

    Raises:
        TypeError: If `prim` is not a `core.Primitive` instance.
        KeyError: If a scaled translation is already registered for this primitive.
    """
    # Explicit check instead of `assert`, so the validation still runs under `python -O`.
    if not isinstance(prim, core.Primitive):
        raise TypeError(f"Expected a JAX primitive (`core.Primitive`), got '{type(prim).__name__}'.")
    # Registration is write-once: silently overriding an existing translation is refused.
    if prim in _scaled_ops_registry:
        raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
    _scaled_ops_registry[prim] = scaled_func


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
    """Find the `jax.lax` primitive matching a scaled function by name.

    The primitive name is built from the function's `__name__` by removing `scaled_`
    and appending `_p`, e.g. `scaled_mul` -> `jax.lax.mul_p`.

    Args:
        scaled_func: Scaled translation function, named `scaled_{func_name}`.

    Returns:
        The `jax.lax` primitive whose name matches the function.

    Raises:
        AttributeError: If `jax.lax` has no attribute with the derived name, or if
            that attribute is not a `core.Primitive`.
    """
    prim_name = scaled_func.__name__.replace("scaled_", "") + "_p"
    try:
        # Only the lookup is guarded, so unrelated AttributeErrors are not relabelled.
        prim = getattr(jax.lax, prim_name)
    except AttributeError:
        raise AttributeError(
            f"Could not find corresponding 'jax.lax' primitive for '{scaled_func.__name__}'."
        ) from None
    # Check as well it is a proper primitive! And not something else also in `jax.lax`
    if not isinstance(prim, core.Primitive):
        raise AttributeError(f"The object `{prim}` is not a proper JAX primitive for '{scaled_func.__name__}'.")
    return prim


def register_scaled_lax_op(scaled_func):
    """
    Registers a scaled function/translation into the scaled_ops_registry, using the
    function name (pattern `scaled_{func_name}`) to look up the matching primitive
    in the `jax.lax` namespace.

    Example: `scaled_mul` is matched to `jax.lax.mul_p`

    Returns `scaled_func` itself, so this can be applied as a decorator.
    """
    register_scaled_op(_get_lax_prim(scaled_func), scaled_func)
    # Hand the function back unchanged so the decorated name stays bound to it.
    return scaled_func


def autoscale(fun):
Expand All @@ -49,7 +66,7 @@ def wrapped(*args, **kwargs):
return wrapped


def autoscale_jaxpr(jaxpr, consts, *args):
def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}

def read(var):
Expand All @@ -67,7 +84,7 @@ def write(var, val):
invals = safe_map(read, eqn.invars)
if eqn.primitive not in _scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
outvals = _scaled_ops_registry[eqn.primitive](*invals)
outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
safe_map(write, eqn.outvars, outvals)
Expand Down
35 changes: 31 additions & 4 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Optional, Sequence

from jax import lax

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape


@core.register_scaled_lax_op
def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)
def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray:
    """Broadcast the data of `A` with `lax.broadcast_in_dim`; the scale is passed through unchanged."""
    broadcast_data = lax.broadcast_in_dim(A.data, shape=shape, broadcast_dimensions=broadcast_dimensions)
    return ScaledArray(broadcast_data, A.scale)


__all__ = ["scaled_mul_p"]
@core.register_scaled_lax_op
def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray:
    """Cast the data of `A` to `new_dtype`; the scale is passed through unchanged.

    `weak_type` is accepted but not used by this translation.
    """
    # NOTE: no rescaling is applied before the cast. Whether to insert a rescaling op
    # beforehand (and with which strategy) is left to the user.
    # NOTE: the scale itself is not cast either.
    converted_data = lax.convert_element_type(A.data, new_dtype=new_dtype)
    return ScaledArray(converted_data, A.scale)


@core.register_scaled_lax_op
def scaled_slice(
    A: ScaledArray, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Optional[Sequence[int]] = None
) -> ScaledArray:
    """Slice the data of `A` with `lax.slice`; the scale is passed through unchanged."""
    sliced_data = lax.slice(A.data, start_indices=start_indices, limit_indices=limit_indices, strides=strides)
    return ScaledArray(sliced_data, A.scale)


@core.register_scaled_lax_op
def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:
    """Transpose the data of `A` with `lax.transpose`; the scale is passed through unchanged."""
    transposed_data = lax.transpose(A.data, permutation=permutation)
    return ScaledArray(transposed_data, A.scale)


@core.register_scaled_lax_op
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
    """Multiply two scaled arrays: data times data, scale times scale."""
    product_data = A.data * B.data
    product_scale = A.scale * B.scale
    return ScaledArray(product_data, product_scale)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme"
[project.optional-dependencies]
test = ["pytest"]

[tool.setuptools]
packages = ["jax_scaled_arithmetics"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-config", "-p no:hypothesispytest"]
Expand Down
21 changes: 19 additions & 2 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array


class AutoScaleInterpreterTests(chex.TestCase):
def test__register_scaled_op__error_if_already_registered(self):
with self.assertRaises(KeyError):
register_scaled_op(jax.lax.mul_p, lambda a, _: a)

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_identity_function(self):
def func(x):
Expand All @@ -34,7 +38,7 @@ def func(x):
assert jaxpr.outvars[1].aval.shape == ()

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_mul_function(self):
def test__scaled_mul__no_attributes(self):
def func(x, y):
return x * y

Expand All @@ -48,3 +52,16 @@ def func(x, y):
out = asfunc(x, y)
assert isinstance(out, ScaledArray)
npt.assert_array_almost_equal(out, expected)

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_convert_element_type__attributes_passing(self):
def func(x):
return jax.lax.convert_element_type(x, np.float16)

# Autoscale + (optional) jitting.
asfunc = self.variant(autoscale(func))
x = scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)
out = asfunc(x)
assert isinstance(out, ScaledArray)
assert out.dtype == np.float16
npt.assert_array_almost_equal(out, x)
51 changes: 51 additions & 0 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.lax import (
scaled_broadcast_in_dim,
scaled_convert_element_type,
scaled_mul,
scaled_slice,
scaled_transpose,
)


class ScaledTranslationPrimitivesTests(chex.TestCase):
    """Direct checks of the scaled translations: result type, scale and data of the output."""

    def test__scaled_broadcast_in_dim__proper_scaling(self):
        # Broadcasting must leave the scale untouched and only reshape the data.
        inp = scaled_array(np.random.rand(5), 2, dtype=np.float32)
        out = scaled_broadcast_in_dim(inp, shape=(5, 1), broadcast_dimensions=(0,))
        assert isinstance(out, ScaledArray)
        npt.assert_array_equal(out.scale, inp.scale)
        npt.assert_array_almost_equal(out.data, inp.data.reshape((5, 1)))

    def test__scaled_convert_element_type__proper_scaling(self):
        # Casting changes the data dtype; the scale is expected to be unchanged.
        inp = scaled_array(np.random.rand(5), 2, dtype=np.float32)
        out = scaled_convert_element_type(inp, new_dtype=np.float16)
        assert isinstance(out, ScaledArray)
        npt.assert_array_equal(out.scale, inp.scale)
        npt.assert_array_almost_equal(out.data, inp.data.astype(out.dtype))

    def test__scaled_transpose__proper_scaling(self):
        # Swapping the two axes must match a plain `.T` on the data.
        inp = scaled_array(np.random.rand(3, 5), 2, dtype=np.float32)
        out = scaled_transpose(inp, (1, 0))
        assert isinstance(out, ScaledArray)
        assert out.scale == inp.scale
        npt.assert_array_almost_equal(out.data, inp.data.T)

    def test__scaled_slice__proper_scaling(self):
        # start=1, limit=4, stride=2 must match the Python slice [1:4:2] on the data.
        inp = scaled_array(np.random.rand(5), 2, dtype=np.float32)
        out = scaled_slice(inp, (1,), (4,), (2,))
        assert isinstance(out, ScaledArray)
        assert out.scale == inp.scale
        npt.assert_array_almost_equal(out.data, inp.data[1:4:2])

    def test__scaled_mul__proper_scaling(self):
        # Scales multiply (3 * 2 == 6); the represented values match the elementwise product.
        lhs = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
        rhs = scaled_array([1.5, 1.5], 2, dtype=np.float32)
        out = scaled_mul(lhs, rhs)
        assert isinstance(out, ScaledArray)
        assert out.scale == 6
        npt.assert_array_almost_equal(out, np.asarray(lhs) * np.asarray(rhs))