-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement
set_scaling
and stop_scaling
JAX primitives.
- Loading branch information
Showing
5 changed files
with
204 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from .base_scaling_primitives import set_scaling, set_scaling_p, stop_scaling, stop_scaling_p # noqa: F401 | ||
from .scaled_ops import * # noqa: F401, F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from typing import Optional, Sequence, Union | ||
|
||
import jax | ||
from jax import core | ||
from jax.interpreters import mlir | ||
from jax.interpreters.mlir import LoweringRuleContext, ir | ||
|
||
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, register_scaled_op | ||
|
||
set_scaling_p = core.Primitive("set_scaling_p") | ||
"""`set_scaling` JAX primitive. | ||
In standard JAX, this is just an identity operation, ignoring the `scale` | ||
input, just returning unchanged the `data` component. | ||
In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to | ||
return a ScaledArray semantically equivalent. | ||
""" | ||
|
||
|
||
def set_scaling(values: jax.Array, scale: jax.Array) -> jax.Array: | ||
"""`set_scaling` primitive call method.""" | ||
return set_scaling_p.bind(values, scale) | ||
|
||
|
||
def set_scaling_impl(values: jax.Array, scale: jax.Array) -> jax.Array: | ||
return values | ||
|
||
|
||
def set_scaling_abstract_eval(values: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray: | ||
return values | ||
|
||
|
||
def set_scaling_mlir_lowering( | ||
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]] | ||
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: | ||
# Just forwarding `values` term, ignoring the `scale`. | ||
return (args[0],) | ||
|
||
|
||
def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray: | ||
"""Scaled `set_scaling` implementation: rebalancing the data using the new scale value.""" | ||
assert isinstance(values, ScaledArray) | ||
assert isinstance(scale, ScaledArray) | ||
assert scale.shape == () | ||
# TODO/FIXME: handle not scaled inputs!!! | ||
scale_value = scale.to_array() | ||
# Rebalancing data tensor using the new scale. | ||
data = values.data * (values.scale / scale_value) | ||
return ScaledArray(data, scale_value) | ||
|
||
|
||
# Register as standard JAX primitive | ||
set_scaling_p.multiple_results = False | ||
set_scaling_p.def_abstract_eval(set_scaling_abstract_eval) | ||
set_scaling_p.def_impl(set_scaling_impl) | ||
mlir.register_lowering(set_scaling_p, set_scaling_mlir_lowering) | ||
# Register "scaled" translation. | ||
register_scaled_op(set_scaling_p, scaled_set_scaling) | ||
|
||
|
||
stop_scaling_p = core.Primitive("stop_scaling_p") | ||
"""`stop_scaling` JAX primitive. | ||
In standard JAX, this is just an identity operation (with optional casting). | ||
In JAX Scaled Arithmetics/AutoScale mode, it will return the value tensor, | ||
with optional casting. | ||
Similar in principle to `jax.lax.stop_gradient` | ||
""" | ||
|
||
|
||
def stop_scaling(values: jax.Array, dtype: Optional[DTypeLike] = None) -> jax.Array: | ||
"""`stop_scaling` primitive call method.""" | ||
return stop_scaling_p.bind(values, dtype=dtype) | ||
|
||
|
||
def stop_scaling_impl(values: jax.Array, dtype: Optional[DTypeLike]) -> jax.Array: | ||
if dtype is not None: | ||
values = values.astype(dtype) | ||
return values | ||
|
||
|
||
def stop_scaling_abstract_eval(values: core.ShapedArray, dtype: Optional[DTypeLike]) -> core.ShapedArray: | ||
print(values, dtype) | ||
if dtype is not None: | ||
values = values.update(dtype=dtype) | ||
return values | ||
|
||
|
||
def stop_scaling_mlir_lowering( | ||
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params | ||
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: | ||
dtype = params.get("dtype", None) | ||
if dtype is not None: | ||
# TODO: caching of the MLIR lowered function? | ||
stop_scaling_mlir_fn = mlir.lower_fun(lambda x: x.astype(dtype), multiple_results=False) | ||
return stop_scaling_mlir_fn(ctx, *args) | ||
# By default: forward tensor. | ||
return (args[0],) | ||
|
||
|
||
def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) -> jax.Array: | ||
"""Scaled `stop_scaling` implementation: returning tensor values (with optional cast).""" | ||
assert isinstance(values, ScaledArray) | ||
# TODO/FIXME: how to handle not scaled input! | ||
out = values.to_array(dtype=dtype) | ||
return out | ||
|
||
|
||
# Register as standard JAX primitive | ||
stop_scaling_p.multiple_results = False | ||
stop_scaling_p.def_abstract_eval(stop_scaling_abstract_eval) | ||
stop_scaling_p.def_impl(stop_scaling_impl) | ||
mlir.register_lowering(stop_scaling_p, stop_scaling_mlir_lowering) | ||
# Register "scaled" translation. | ||
register_scaled_op(stop_scaling_p, scaled_stop_scaling) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import numpy.testing as npt | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array | ||
from jax_scaled_arithmetics.lax import set_scaling, stop_scaling | ||
|
||
|
||
class SetScalingPrimitiveTests(chex.TestCase): | ||
@chex.variants(with_jit=True, without_jit=True) | ||
def test__set_scaling_primitive__proper_result_without_autoscale(self): | ||
def fn(arr, scale): | ||
return set_scaling(arr, scale) | ||
|
||
fn = self.variant(fn) | ||
arr = jnp.array([2, 3], dtype=np.float32) | ||
scale = jnp.array(4, dtype=np.float32) | ||
out = fn(arr, scale) | ||
npt.assert_array_equal(out, arr) | ||
|
||
@chex.variants(with_jit=True, without_jit=True) | ||
def test__set_scaling_primitive__proper_result_with_autoscale(self): | ||
def fn(arr, scale): | ||
return set_scaling(arr, scale) | ||
|
||
fn = self.variant(autoscale(fn)) | ||
arr = scaled_array([-1.0, 2.0], 1.0, dtype=np.float32) | ||
# TODO: support scalar here! | ||
scale = scaled_array(1.0, 4.0, dtype=np.float32) | ||
out = fn(arr, scale) | ||
# Unchanged output tensor! | ||
assert isinstance(out, ScaledArray) | ||
npt.assert_array_equal(out.scale, scale) | ||
npt.assert_array_equal(out, arr) | ||
|
||
|
||
class StopScalingPrimitiveTests(chex.TestCase): | ||
@chex.variants(with_jit=True, without_jit=True) | ||
def test__stop_scaling_primitive__proper_result_without_autoscale(self): | ||
def fn(arr): | ||
# Testing both variants. | ||
return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) | ||
|
||
arr = jnp.array([2, 3], dtype=np.float32) | ||
out0, out1 = self.variant(fn)(arr) | ||
assert out0.dtype == arr.dtype | ||
assert out1.dtype == np.float16 | ||
npt.assert_array_equal(out0, arr) | ||
npt.assert_array_almost_equal(out1, arr) | ||
|
||
@chex.variants(with_jit=True, without_jit=True) | ||
def test__stop_scaling_primitive__proper_result_with_autoscale(self): | ||
def fn(arr): | ||
# Testing both variants. | ||
return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) | ||
|
||
fn = self.variant(autoscale(fn)) | ||
arr = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32) | ||
out0, out1 = fn(arr) | ||
assert isinstance(out0, jax.Array) | ||
assert isinstance(out1, jax.Array) | ||
assert out0.dtype == arr.dtype | ||
assert out1.dtype == np.float16 | ||
npt.assert_array_equal(out0, arr) | ||
npt.assert_array_almost_equal(out1, arr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters