Skip to content

Commit

Permalink
Implement set_scaling and stop_scaling JAX primitives.
Browse files Browse the repository at this point in the history
* `set_scaling`: Set the scaling of a tensor, transforming into `ScaledArray` in `autoscale` mode.
* `stop_scaling`: Stop scale propagation of a tensor, transforming back into a JAX array.

Both operations are no-op identity operations in normal JAX mode.

Note: as pointed by @DouglasOrr, these primitives could also be formalized as casting operations, where ScaledDtypes are
properly defined. To be clarified whether it may be a better setting for the JAX implementation!
  • Loading branch information
balancap committed Nov 17, 2023
1 parent d476f74 commit 07df3b6
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 9 deletions.
4 changes: 3 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def to_scaled_array(val):

# Primitive is supported by `autoscale`?
if eqn.primitive not in _scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
)
outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
Expand Down
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .base_scaling_primitives import set_scaling, set_scaling_p, stop_scaling, stop_scaling_p # noqa: F401
from .scaled_ops import * # noqa: F401, F403
115 changes: 115 additions & 0 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Optional, Sequence, Union

import jax
from jax import core
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir

from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, register_scaled_op

set_scaling_p = core.Primitive("set_scaling_p")
"""`set_scaling` JAX primitive.
In standard JAX, this is just an identity operation, ignoring the `scale`
input, just returning unchanged the `data` component.
In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to
return a ScaledArray semantically equivalent.
"""


def set_scaling(values: jax.Array, scale: jax.Array) -> jax.Array:
"""`set_scaling` primitive call method."""
return set_scaling_p.bind(values, scale)


def set_scaling_impl(values: jax.Array, scale: jax.Array) -> jax.Array:
return values


def set_scaling_abstract_eval(values: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray:
return values


def set_scaling_mlir_lowering(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]]
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
# Just forwarding `values` term, ignoring the `scale`.
return (args[0],)


def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray:
"""Scaled `set_scaling` implementation: rebalancing the data using the new scale value."""
assert isinstance(values, ScaledArray)
assert isinstance(scale, ScaledArray)
assert scale.shape == ()
# TODO/FIXME: handle not scaled inputs!!!
scale_value = scale.to_array()
# Rebalancing data tensor using the new scale.
data = values.data * (values.scale / scale_value)
return ScaledArray(data, scale_value)


# Register as standard JAX primitive
set_scaling_p.multiple_results = False
set_scaling_p.def_abstract_eval(set_scaling_abstract_eval)
set_scaling_p.def_impl(set_scaling_impl)
mlir.register_lowering(set_scaling_p, set_scaling_mlir_lowering)
# Register "scaled" translation.
register_scaled_op(set_scaling_p, scaled_set_scaling)


stop_scaling_p = core.Primitive("stop_scaling_p")
"""`stop_scaling` JAX primitive.
In standard JAX, this is just an identity operation (with optional casting).
In JAX Scaled Arithmetics/AutoScale mode, it will return the value tensor,
with optional casting.
Similar in principle to `jax.lax.stop_gradient`
"""


def stop_scaling(values: jax.Array, dtype: Optional[DTypeLike] = None) -> jax.Array:
"""`stop_scaling` primitive call method."""
return stop_scaling_p.bind(values, dtype=dtype)


def stop_scaling_impl(values: jax.Array, dtype: Optional[DTypeLike]) -> jax.Array:
if dtype is not None:
values = values.astype(dtype)
return values


def stop_scaling_abstract_eval(values: core.ShapedArray, dtype: Optional[DTypeLike]) -> core.ShapedArray:
return values.update(dtype=dtype)


def stop_scaling_mlir_lowering(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
dtype = params.get("dtype", None)
if dtype is not None:
# TODO: caching of the MLIR lowered function?
stop_scaling_mlir_fn = mlir.lower_fun(lambda x: x.astype(dtype), multiple_results=False)
return stop_scaling_mlir_fn(ctx, *args)
# By default: forward tensor.
return (args[0],)


def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) -> jax.Array:
"""Scaled `stop_scaling` implementation: returning tensor values (with optional cast)."""
assert isinstance(values, ScaledArray)
# TODO/FIXME: how to handle not scaled input.
return values.to_array(dtype=dtype)


# Register as standard JAX primitive
stop_scaling_p.multiple_results = False
stop_scaling_p.def_abstract_eval(stop_scaling_abstract_eval)
stop_scaling_p.def_impl(stop_scaling_impl)
mlir.register_lowering(stop_scaling_p, stop_scaling_mlir_lowering)
# Register "scaled" translation.
register_scaled_op(stop_scaling_p, scaled_stop_scaling)
68 changes: 68 additions & 0 deletions tests/lax/test_base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.lax import set_scaling, stop_scaling


class SetScalingPrimitiveTests(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test__set_scaling_primitive__proper_result_without_autoscale(self):
def fn(arr, scale):
return set_scaling(arr, scale)

fn = self.variant(fn)
arr = jnp.array([2, 3], dtype=np.float32)
scale = jnp.array(4, dtype=np.float32)
out = fn(arr, scale)
npt.assert_array_equal(out, arr)

@chex.variants(with_jit=True, without_jit=True)
def test__set_scaling_primitive__proper_result_with_autoscale(self):
def fn(arr, scale):
return set_scaling(arr, scale)

fn = self.variant(autoscale(fn))
arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32)
# TODO: support scalar here!
scale = scaled_array(1.0, 4.0, dtype=np.float32)
out = fn(arr, scale)
# Unchanged output tensor!
assert isinstance(out, ScaledArray)
npt.assert_array_equal(out.scale, scale)
npt.assert_array_equal(out, arr)


class StopScalingPrimitiveTests(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test__stop_scaling_primitive__proper_result_without_autoscale(self):
def fn(arr):
# Testing both variants.
return stop_scaling(arr), stop_scaling(arr, dtype=np.float16)

arr = jnp.array([2, 3], dtype=np.float32)
out0, out1 = self.variant(fn)(arr)
assert out0.dtype == arr.dtype
assert out1.dtype == np.float16
npt.assert_array_equal(out0, arr)
npt.assert_array_almost_equal(out1, arr)

@chex.variants(with_jit=True, without_jit=True)
def test__stop_scaling_primitive__proper_result_with_autoscale(self):
def fn(arr):
# Testing both variants.
return stop_scaling(arr), stop_scaling(arr, dtype=np.float16)

fn = self.variant(autoscale(fn))
arr = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
out0, out1 = fn(arr)
assert isinstance(out0, jax.Array)
assert isinstance(out1, jax.Array)
assert out0.dtype == arr.dtype
assert out1.dtype == np.float16
npt.assert_array_equal(out0, arr)
npt.assert_array_almost_equal(out1, arr)
21 changes: 13 additions & 8 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,42 @@


class ScaledTranslationPrimitivesTests(chex.TestCase):
def setUp(self):
super().setUp()
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

def test__scaled_broadcast_in_dim__proper_scaling(self):
x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,))
assert isinstance(z, ScaledArray)
npt.assert_array_equal(z.scale, x.scale)
npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1)))

def test__scaled_concatenate__proper_scaling(self):
x = scaled_array(np.random.rand(2, 3), 0.5, dtype=np.float32)
y = scaled_array(np.random.rand(5, 3), 2, dtype=np.float32)
x = scaled_array(self.rs.rand(2, 3), 0.5, dtype=np.float32)
y = scaled_array(self.rs.rand(5, 3), 2, dtype=np.float32)
z = scaled_concatenate([x, y], dimension=0)
assert isinstance(z, ScaledArray)
npt.assert_array_equal(z.scale, y.scale)
npt.assert_array_almost_equal(z, np.concatenate([x, y], axis=0))

def test__scaled_convert_element_type__proper_scaling(self):
x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_convert_element_type(x, new_dtype=np.float16)
assert isinstance(z, ScaledArray)
npt.assert_array_equal(z.scale, x.scale)
npt.assert_array_almost_equal(z.data, x.data.astype(z.dtype))

def test__scaled_transpose__proper_scaling(self):
x = scaled_array(np.random.rand(3, 5), 2, dtype=np.float32)
x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32)
z = scaled_transpose(x, (1, 0))
assert isinstance(z, ScaledArray)
assert z.scale == x.scale
npt.assert_array_almost_equal(z.data, x.data.T)

def test__scaled_slice__proper_scaling(self):
x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
x = scaled_array(self.rs.rand(5), 2, dtype=np.float32)
z = scaled_slice(x, (1,), (4,), (2,))
assert isinstance(z, ScaledArray)
assert z.scale == x.scale
Expand Down Expand Up @@ -81,8 +86,8 @@ def test__scaled_sub__proper_scaling(self):
npt.assert_array_almost_equal(z, np.asarray(x) - np.asarray(y))

def test__scaled_dot_general__proper_scaling(self):
lhs = scaled_array(np.random.rand(3, 5), 2.0, dtype=np.float32)
rhs = scaled_array(np.random.rand(5, 2), 3.0, dtype=np.float32)
lhs = scaled_array(self.rs.rand(3, 5), 2.0, dtype=np.float32)
rhs = scaled_array(self.rs.rand(5, 2), 3.0, dtype=np.float32)
out = scaled_dot_general(lhs, rhs, (((1,), (0,)), ((), ())))
assert isinstance(out, ScaledArray)
assert out.dtype == lhs.dtype
Expand Down

0 comments on commit 07df3b6

Please sign in to comment.