Skip to content

Commit

Permalink
Implement set_scaling and stop_scaling JAX primitives.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Nov 16, 2023
1 parent d476f74 commit 65bb852
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def to_scaled_array(val):

# Primitive is supported by `autoscale`?
if eqn.primitive not in _scaled_ops_registry:
raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
)
outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
Expand Down
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .base_scaling_primitives import set_scaling, set_scaling_p # noqa: F401
from .scaled_ops import * # noqa: F401, F403
67 changes: 67 additions & 0 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Sequence, Union

import jax
from jax import core
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir

from jax_scaled_arithmetics.core import ScaledArray, register_scaled_op

set_scaling_p = core.Primitive("set_scaling_p")
"""`set_scaling` JAX primitive.
In standard JAX, this is just an identity operation, ignoring the `scale`
input, just returning unchanged the `data` component.
In JAX Scaled Arithmetics mode, it will rebalance the data term to
return a ScaledArray semantically equivalent.
"""


def set_scaling(data: jax.Array, scale: jax.Array) -> jax.Array:
return set_scaling_p.bind(data, scale)


def set_scaling_impl(data: jax.Array, scale: jax.Array) -> jax.Array:
return data


def set_scaling_abstract_eval(data: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray:
return data


def set_default_mlir_lowering(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]]
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
# Just forwarding `data` term, ignoring the scaling.
return (args[0],)


def scaled_set_scaling(val: ScaledArray, scale: ScaledArray) -> ScaledArray:
"""Scaled `set_scaling` implementation: rebalancing the data using the new scale value."""
assert isinstance(val, ScaledArray)
assert isinstance(scale, ScaledArray)
assert scale.shape == ()

scale_value = scale.to_array()
# Rebalancing data term.
data = val.data * (val.scale / scale_value)
return ScaledArray(data, scale_value)


set_scaling_p.multiple_results = False
# Register the abstract evaluation with JAX
set_scaling_p.def_abstract_eval(set_scaling_abstract_eval)
# Register the primal implementation with JAX
set_scaling_p.def_impl(set_scaling_impl)
# Register MLIR lowering.
mlir.register_lowering(set_scaling_p, set_default_mlir_lowering)
register_scaled_op(set_scaling_p, scaled_set_scaling)


stop_scaling_p = core.Primitive("stop_scaling_p")
"""`stop_scaling` JAX primitive.
Similar in principle to `jax.lax.stop_gradient`
"""
36 changes: 36 additions & 0 deletions tests/lax/test_base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.lax import set_scaling


class SetScalingPrimitiveTests(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test__set_scaling_primitive__proper_result_without_autoscale(self):
def fn(arr, scale):
return set_scaling(arr, scale)

fn = self.variant(fn)
arr = jnp.array([2, 3], dtype=np.float32)
scale = jnp.array(4, dtype=np.float32)
out = fn(arr, scale)
npt.assert_array_equal(out, arr)

@chex.variants(with_jit=True, without_jit=False)
def test__set_scaling_primitive__proper_result_with_autoscale(self):
def fn(arr, scale):
return set_scaling(arr, scale)

fn = self.variant(autoscale(fn))
arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32)
# TODO: support scalar here!
scale = scaled_array(1.0, 4.0, dtype=np.float32)
out = fn(arr, scale)
# Unchanged output tensor!
assert isinstance(out, ScaledArray)
npt.assert_array_equal(out.scale, scale)
npt.assert_array_equal(out, arr)

0 comments on commit 65bb852

Please sign in to comment.