-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement
set_scaling
and stop_scaling
JAX primitives.
- Loading branch information
Showing
4 changed files
with
107 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from .base_scaling_primitives import set_scaling, set_scaling_p # noqa: F401 | ||
from .scaled_ops import * # noqa: F401, F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from typing import Sequence, Union | ||
|
||
import jax | ||
from jax import core | ||
from jax.interpreters import mlir | ||
from jax.interpreters.mlir import LoweringRuleContext, ir | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray, register_scaled_op | ||
|
||
set_scaling_p = core.Primitive("set_scaling_p") | ||
"""`set_scaling` JAX primitive. | ||
In standard JAX, this is just an identity operation, ignoring the `scale` | ||
input, just returning unchanged the `data` component. | ||
In JAX Scaled Arithmetics mode, it will rebalance the data term to | ||
return a ScaledArray semantically equivalent. | ||
""" | ||
|
||
|
||
def set_scaling(data: jax.Array, scale: jax.Array) -> jax.Array: | ||
return set_scaling_p.bind(data, scale) | ||
|
||
|
||
def set_scaling_impl(data: jax.Array, scale: jax.Array) -> jax.Array: | ||
return data | ||
|
||
|
||
def set_scaling_abstract_eval(data: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray: | ||
return data | ||
|
||
|
||
def set_default_mlir_lowering( | ||
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]] | ||
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: | ||
# Just forwarding `data` term, ignoring the scaling. | ||
return (args[0],) | ||
|
||
|
||
def scaled_set_scaling(val: ScaledArray, scale: ScaledArray) -> ScaledArray: | ||
"""Scaled `set_scaling` implementation: rebalancing the data using the new scale value.""" | ||
assert isinstance(val, ScaledArray) | ||
assert isinstance(scale, ScaledArray) | ||
assert scale.shape == () | ||
|
||
scale_value = scale.to_array() | ||
# Rebalancing data term. | ||
data = val.data * (val.scale / scale_value) | ||
return ScaledArray(data, scale_value) | ||
|
||
|
||
set_scaling_p.multiple_results = False | ||
# Register the abstract evaluation with JAX | ||
set_scaling_p.def_abstract_eval(set_scaling_abstract_eval) | ||
# Register the primal implementation with JAX | ||
set_scaling_p.def_impl(set_scaling_impl) | ||
# Register MLIR lowering. | ||
mlir.register_lowering(set_scaling_p, set_default_mlir_lowering) | ||
register_scaled_op(set_scaling_p, scaled_set_scaling) | ||
|
||
|
||
stop_scaling_p = core.Primitive("stop_scaling_p") | ||
"""`stop_scaling` JAX primitive. | ||
Similar in principle to `jax.lax.stop_gradient` | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import chex | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import numpy.testing as npt | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array | ||
from jax_scaled_arithmetics.lax import set_scaling | ||
|
||
|
||
class SetScalingPrimitiveTests(chex.TestCase): | ||
@chex.variants(with_jit=True, without_jit=True) | ||
def test__set_scaling_primitive__proper_result_without_autoscale(self): | ||
def fn(arr, scale): | ||
return set_scaling(arr, scale) | ||
|
||
fn = self.variant(fn) | ||
arr = jnp.array([2, 3], dtype=np.float32) | ||
scale = jnp.array(4, dtype=np.float32) | ||
out = fn(arr, scale) | ||
npt.assert_array_equal(out, arr) | ||
|
||
@chex.variants(with_jit=True, without_jit=False) | ||
def test__set_scaling_primitive__proper_result_with_autoscale(self): | ||
def fn(arr, scale): | ||
return set_scaling(arr, scale) | ||
|
||
fn = self.variant(autoscale(fn)) | ||
arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32) | ||
# TODO: support scalar here! | ||
scale = scaled_array(1.0, 4.0, dtype=np.float32) | ||
out = fn(arr, scale) | ||
# Unchanged output tensor! | ||
assert isinstance(out, ScaledArray) | ||
npt.assert_array_equal(out.scale, scale) | ||
npt.assert_array_equal(out, arr) |