Implement reduce_sum/prod/max/min scaled translation rules. #26

Merged (1 commit) on Nov 21, 2023
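This PR adds scaled translation rules for the four LAX reduce primitives: `reduce_sum` divides the data component by sqrt(N) over the reduced axes and multiplies the scale by sqrt(N), `reduce_prod` raises the scale to the power N, and `reduce_max`/`reduce_min` pass the scale through unchanged. It also extends the core registry so that each primitive maps to a `(scaled_func, ScaledPrimitiveType)` pair and exposes a new `find_registered_scaled_op` lookup helper. Short illustrative sketches follow the relevant file diffs below.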
8 changes: 7 additions & 1 deletion jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,9 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .datatype import DTypeLike, ScaledArray, Shape, is_scaled_leaf, scaled_array # noqa: F401
from .interpreters import ScaledPrimitiveType, autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
from .interpreters import ( # noqa: F401
ScaledPrimitiveType,
autoscale,
find_registered_scaled_op,
register_scaled_lax_op,
register_scaled_op,
)
47 changes: 29 additions & 18 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -1,7 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from enum import IntEnum
from functools import wraps
from typing import Any, Dict
from typing import Any, Dict, Tuple

import jax
import numpy as np
@@ -10,7 +10,24 @@

from .datatype import NDArray, ScaledArray

_scaled_ops_registry: Dict[core.Primitive, Any] = {}

class ScaledPrimitiveType(IntEnum):
"""Scale (JAX) primitive type.

This enum described the behaviour when `autoscale` is
tracing the graph.

FORWARD: Forwarding scaling => only used if scaled inputs.
Default behaviour.
ALWAYS_SCALE: Always use scaled version.
"""

NEVER = 0
FORWARD = 1
ALWAYS_SCALE = 2


_scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {}


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
@@ -43,22 +60,6 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
return ScaledArray(data=onedata, scale=val)


class ScaledPrimitiveType(IntEnum):
"""Scale (JAX) primitive type.

This enum described the behaviour when `autoscale` is
tracing the graph.

FORWARD: Forwarding scaling => only used if scaled inputs.
Default behaviour.
ALWAYS_SCALE: Always use scaled version.
"""

NEVER = 0
FORWARD = 1
ALWAYS_SCALE = 2


def numpy_constant_scaled_array(val: NDArray[Any]) -> ScaledArray:
"""Get the ScaledArray corresponding to a Numpy constant.

@@ -102,6 +103,16 @@ def register_scaled_lax_op(scaled_func):
return scaled_func


def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiveType]:
"""Find a registered JAX scaled operation/translation. Returns (None, None) if
the primitive does not have a scaled translation registered.

Args:
prim: JAX primitive.
"""
return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


def autoscale(fun):
"""`autoscale` JAX graph transformation.

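With this change the registry maps every primitive to a `(scaled_func, ScaledPrimitiveType)` pair, and `find_registered_scaled_op` is the lookup that decides whether a scaled translation exists. A minimal lookup sketch (illustrative, not part of the diff; it assumes the package is importable, that the reduce translations below are registered, and that `lax.sort_p` has no scaled translation):

```python
from jax import lax

from jax_scaled_arithmetics.core import ScaledPrimitiveType, find_registered_scaled_op

# After this PR, reduce_sum_p has a scaled translation registered.
scaled_fn, prim_type = find_registered_scaled_op(lax.reduce_sum_p)
print(scaled_fn, prim_type)  # scaled_reduce_sum and its ScaledPrimitiveType

# Primitives without a registered translation fall back to (None, NEVER);
# lax.sort_p is assumed to be unregistered here.
missing_fn, missing_type = find_registered_scaled_op(lax.sort_p)
assert missing_fn is None and missing_type == ScaledPrimitiveType.NEVER
```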
38 changes: 38 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -107,3 +107,41 @@ def scaled_dot_general(
/ contracting_rescale
)
return ScaledArray(output_data, output_scale)


@core.register_scaled_lax_op
def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
axes_rescale = np.sqrt(np.prod(axes_size))
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale
outscale = val.scale * axes_rescale
return ScaledArray(data, outscale)


@core.register_scaled_lax_op
def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)


@core.register_scaled_lax_op
def scaled_reduce_max(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_max_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_min(val: ScaledArray, axes: Tuple[int, ...]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_min_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)
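The rationale for the `reduce_sum` rule: summing N elements of roughly unit magnitude grows the result by about sqrt(N) (the standard deviation of a sum of N independent unit-variance values), so the data component is divided by sqrt(N) and the sqrt(N) factor is folded into the scale, keeping `data * scale` equal to the plain sum. A quick check of that rule, as a sketch that mirrors the parameterized test below (assumes the package is importable; the lookup goes through `find_registered_scaled_op`, as in the test):

```python
import numpy as np
from jax import lax

from jax_scaled_arithmetics.core import find_registered_scaled_op, scaled_array

scaled_reduce_sum, _ = find_registered_scaled_op(lax.reduce_sum_p)

val = scaled_array(np.random.rand(5), 2.0, dtype=np.float32)
out = scaled_reduce_sum(val, axes=(0,))

# Scale is multiplied by sqrt(5) while the data is divided by sqrt(5),
# so the represented value (data * scale) still matches the plain sum.
np.testing.assert_allclose(out.scale, 2.0 * np.sqrt(5), rtol=1e-5)
np.testing.assert_allclose(np.asarray(out), np.sum(np.asarray(val)), rtol=1e-5)
```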
29 changes: 28 additions & 1 deletion tests/lax/test_scaled_ops.py
@@ -2,8 +2,10 @@
import chex
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax import lax

from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, find_registered_scaled_op, scaled_array
from jax_scaled_arithmetics.lax import (
scaled_add,
scaled_broadcast_in_dim,
@@ -93,3 +95,28 @@ def test__scaled_dot_general__proper_scaling(self):
assert out.dtype == lhs.dtype
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs))


class ScaledTranslationReducePrimitivesTests(chex.TestCase):
def setUp(self):
super().setUp()
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

@parameterized.parameters(
{"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * np.sqrt(5)},
{"reduce_prim": lax.reduce_prod_p, "expected_scale": 2**5},
{"reduce_prim": lax.reduce_min_p, "expected_scale": 2},
{"reduce_prim": lax.reduce_max_p, "expected_scale": 2},
)
def test__scaled_reduce__single_axis__proper_scaling(self, reduce_prim, expected_scale):
axes = (0,)
val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32)
scaled_reduce_op, _ = find_registered_scaled_op(reduce_prim)
out = scaled_reduce_op(val, axes=axes)

assert isinstance(out, ScaledArray)
assert out.shape == ()
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, expected_scale)
npt.assert_array_almost_equal(out, reduce_prim.bind(np.asarray(val), axes=axes))
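For completeness, the intended user-facing path is the `autoscale` transformation rather than calling the translations directly. A hypothetical end-to-end sketch (hedged: the exact `autoscale` tracing semantics are not shown in this diff, and `jnp.sum` may lower to additional primitives that also need scaled translations, so this is illustrative only):

```python
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array

@autoscale
def total(x):
    # jnp.sum lowers to the reduce_sum primitive, which now has a scaled translation.
    return jnp.sum(x, axis=0)

out = total(scaled_array(np.ones((5,), np.float32), 2.0, dtype=np.float32))
# Expected under the assumptions above: a ScaledArray with scale close to 2 * sqrt(5).
print(isinstance(out, ScaledArray), getattr(out, "scale", None))
```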