Skip to content

Commit

Permalink
Implement reduce_sum/prod/max/min scaled translation rules.
Browse files Browse the repository at this point in the history
TODO!
  • Loading branch information
balancap committed Nov 21, 2023
1 parent 642a58e commit 896c8d9
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 20 deletions.
8 changes: 7 additions & 1 deletion jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array # noqa: F401
from .interpreters import ScaledPrimitiveType, autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
from .interpreters import ( # noqa: F401
ScaledPrimitiveType,
autoscale,
find_registered_scaled_op,
register_scaled_lax_op,
register_scaled_op,
)
47 changes: 29 additions & 18 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from enum import IntEnum
from functools import wraps
from typing import Any, Dict
from typing import Any, Dict, Tuple

import jax
import numpy as np
Expand All @@ -10,7 +10,24 @@

from .datatype import NDArray, ScaledArray

_scaled_ops_registry: Dict[core.Primitive, Any] = {}

class ScaledPrimitiveType(IntEnum):
"""Scale (JAX) primitive type.
This enum described the behaviour when `autoscale` is
tracing the graph.
FORWARD: Forwarding scaling => only used if scaled inputs.
Default behaviour.
ALWAYS_SCALE: Always use scaled version.
"""

NEVER = 0
FORWARD = 1
ALWAYS_SCALE = 2


_scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {}


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
Expand Down Expand Up @@ -43,22 +60,6 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
return ScaledArray(data=onedata, scale=val)


class ScaledPrimitiveType(IntEnum):
"""Scale (JAX) primitive type.
This enum described the behaviour when `autoscale` is
tracing the graph.
FORWARD: Forwarding scaling => only used if scaled inputs.
Default behaviour.
ALWAYS_SCALE: Always use scaled version.
"""

NEVER = 0
FORWARD = 1
ALWAYS_SCALE = 2


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
"""Get the ScaledArray corresponding to a Numpy constant.
Expand Down Expand Up @@ -102,6 +103,16 @@ def register_scaled_lax_op(scaled_func):
return scaled_func


def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiveType]:
"""Find a registered JAX scaled operation/translation. Returns (None, None) if
the primitive does not have a scaled translation registered.
Args:
prim: JAX primitive.
"""
return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


def autoscale(fun):
"""`autoscale` JAX graph transformation.
Expand Down
38 changes: 38 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,41 @@ def scaled_dot_general(
/ contracting_rescale
)
return ScaledArray(output_data, output_scale)


@core.register_scaled_lax_op
def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
axes_rescale = np.sqrt(np.prod(axes_size))
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale
outscale = val.scale * axes_rescale
return ScaledArray(data, outscale)


@core.register_scaled_lax_op
def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)


@core.register_scaled_lax_op
def scaled_reduce_max(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_max_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_min_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)
29 changes: 28 additions & 1 deletion tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import chex
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax import lax

from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, find_registered_scaled_op, scaled_array
from jax_scaled_arithmetics.lax import (
scaled_add,
scaled_broadcast_in_dim,
Expand Down Expand Up @@ -93,3 +95,28 @@ def test__scaled_dot_general__proper_scaling(self):
assert out.dtype == lhs.dtype
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs))


class ScaledTranslationReducePrimitivesTests(chex.TestCase):
def setUp(self):
super().setUp()
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

@parameterized.parameters(
{"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * np.sqrt(5)},
{"reduce_prim": lax.reduce_prod_p, "expected_scale": 2**5},
{"reduce_prim": lax.reduce_min_p, "expected_scale": 2},
{"reduce_prim": lax.reduce_max_p, "expected_scale": 2},
)
def test__scaled_reduce__single_axis__proper_scaling(self, reduce_prim, expected_scale):
axes = (0,)
val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
scaled_reduce_op, _ = find_registered_scaled_op(reduce_prim)
out = scaled_reduce_op(val, axes=axes)

assert isinstance(out, ScaledArray)
assert out.shape == ()
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, expected_scale)
npt.assert_array_almost_equal(out, reduce_prim.bind(np.asarray(val), axes=axes))

0 comments on commit 896c8d9

Please sign in to comment.