Skip to content

Commit

Permalink
Default scaled translation of JAX primitives.
Browse files Browse the repository at this point in the history
Default scaled translation (i.e. unscale/primitive/rescale) for `log`, `exp` and `select`.
This will unlock minimal operator coverage for MNIST/MLP experiments.
  • Loading branch information
balancap committed Nov 22, 2023
1 parent 9084990 commit 8072914
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array # noqa: F401
from .datatype import DTypeLike, ScaledArray, Shape, asarray, is_scaled_leaf, scaled_array # noqa: F401
from .interpreters import ( # noqa: F401
ScaledPrimitiveType,
autoscale,
Expand Down
12 changes: 12 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npa
return ScaledArray(data, scale)


def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
"""Convert back to a common JAX/Numpy array.
Args:
dtype: Optional dtype of the final array.
"""
if isinstance(val, ScaledArray):
return val.to_array()
# TODO: other special cases?
return val


def is_scaled_leaf(val: Any) -> bool:
"""Is input a JAX PyTree (scaled) leaf, including ScaledArray.
Expand Down
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax.interpreters import mlir
from jax.interpreters.mlir import LoweringRuleContext, ir

from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, ScaledPrimitiveType, register_scaled_op
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, ScaledPrimitiveType, asarray, register_scaled_op

set_scaling_p = core.Primitive("set_scaling_p")
"""`set_scaling` JAX primitive.
Expand Down Expand Up @@ -43,7 +43,7 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray:
"""Scaled `set_scaling` implementation: rebalancing the data using the new scale value."""
assert scale.shape == ()
# Automatic promotion should ensure we always get a scaled scalar here!
scale_value = scale.to_array()
scale_value = asarray(scale)
if not isinstance(values, ScaledArray):
# Simple case, with no pre-existing scale.
return ScaledArray(values / scale_value, scale_value)
Expand Down
46 changes: 46 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@
from typing import Any, Optional, Sequence, Tuple

import jax
import jax.core
import jax.numpy as jnp
import numpy as np
from jax import lax

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape

from .base_scaling_primitives import scaled_set_scaling


@core.register_scaled_lax_op
def scaled_stop_gradient(val: ScaledArray) -> ScaledArray:
# Stop gradients on both data and scale tensors.
return ScaledArray(lax.stop_gradient(val.data), lax.stop_gradient(val.scale))


@core.register_scaled_lax_op
def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray:
Expand Down Expand Up @@ -195,3 +204,40 @@ def scaled_lt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
@core.register_scaled_lax_op
def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.le_p)


##################################################################
# Default scaled ops implementation #
##################################################################
def scaled_op_default_translation(
prim: jax.core.Primitive, args: Sequence[ScaledArray], outscale: Optional[jax.Array] = None
) -> ScaledArray:
"""Scaled op default translation of a JAX primitive: unscaling inputs + calling normal primitive.
Args:
prim: JAX primitive
args: Input arguments.
outscale: Output scale to use.
"""
inputs = [core.asarray(v) for v in args]
output = prim.bind(*inputs)
# Rescale output, if necessary.
if outscale is None:
return ScaledArray(output, np.array(1.0, dtype=output.dtype))
output_scaled = scaled_set_scaling(output, outscale)
return output_scaled


@core.register_scaled_lax_op
def scaled_exp(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.exp_p, [val])


@core.register_scaled_lax_op
def scaled_log(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.log_p, [val])


@core.register_scaled_lax_op
def scaled_select_n(which: jax.Array, *cases: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.select_n_p, [which, *cases])
29 changes: 29 additions & 0 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
scaled_concatenate,
scaled_convert_element_type,
scaled_dot_general,
scaled_exp,
scaled_is_finite,
scaled_log,
scaled_mul,
scaled_select_n,
scaled_slice,
scaled_sub,
scaled_transpose,
Expand Down Expand Up @@ -98,6 +101,22 @@ def test__scaled_dot_general__proper_scaling(self):
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs))

def test__scaled_exp__proper_scaling(self):
val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=np.float32)
out = scaled_exp(val)
assert isinstance(out, ScaledArray)
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, 1) # FIXME!
npt.assert_array_almost_equal(out, np.exp(val))

def test__scaled_log__proper_scaling(self):
val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=np.float32)
out = scaled_log(val)
assert isinstance(out, ScaledArray)
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, 1) # FIXME!
npt.assert_array_almost_equal(out, np.log(val))


class ScaledTranslationReducePrimitivesTests(chex.TestCase):
def setUp(self):
Expand Down Expand Up @@ -161,3 +180,13 @@ def test__scaled_boolean_binary_op__proper_result(self, bool_prim):
assert out0.dtype == np.bool_
npt.assert_array_equal(out0, bool_prim.bind(lhs.to_array(), rhs.to_array()))
npt.assert_array_equal(out1, bool_prim.bind(lhs.to_array(), lhs.to_array()))

def test__scaled_select_n__proper_result(self):
mask = self.rs.rand(5) > 0.5
lhs = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
rhs = scaled_array(self.rs.rand(5), 3.0, dtype=np.float32)
out = scaled_select_n(mask, lhs, rhs)
assert isinstance(out, ScaledArray)
assert out.dtype == np.float32
npt.assert_almost_equal(out.scale, 1) # FIXME!
npt.assert_array_equal(out, np.where(mask, rhs, lhs))

0 comments on commit 8072914

Please sign in to comment.