Re-organize scaled ops translations into common and L2 rules. (#57)
Properly separate common ops, which are independent of the scaling strategy, from the L2/Gaussian
scaling ops/translations (which can potentially be altered).
balancap authored Dec 18, 2023
1 parent ea8c961 commit 235806a
Showing 5 changed files with 363 additions and 347 deletions.
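
For context, the commit introduces scaled_ops_common.py for translations that do not depend on the scaling strategy and scaled_ops_l2.py for the L2/Gaussian rules, as the __init__.py change below shows. A minimal sketch of the L2 combination rule itself, using plain NumPy scalars (the helper name is illustrative and not code from this commit):

import numpy as np

def l2_combine_scales(a_scale, b_scale):
    # L2/Gaussian rule: treat the operand scales as independent standard
    # deviations and combine them in quadrature, as scaled_add/scaled_sub
    # do in the diff below.
    return np.sqrt(a_scale * a_scale + b_scale * b_scale)

# Combining scales 3.0 and 4.0 gives 5.0.
assert np.isclose(l2_combine_scales(3.0, 4.0), 5.0)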
jax_scaled_arithmetics/lax/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,3 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .base_scaling_primitives import set_scaling, set_scaling_p, stop_scaling, stop_scaling_p # noqa: F401
from .scaled_ops import * # noqa: F401, F403
from .scaled_ops_common import * # noqa: F401, F403
from .scaled_ops_l2 import * # noqa: F401, F403
@@ -1,23 +1,14 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence

import jax
import jax.core
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax._src.ad_util import add_any_p

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import (
Array,
DTypeLike,
ScaledArray,
Shape,
as_scaled_array,
is_static_zero,
register_scaled_op,
)
from jax_scaled_arithmetics.core import Array, DTypeLike, ScaledArray, Shape, as_scaled_array, is_static_zero

from .base_scaling_primitives import scaled_set_scaling

@@ -158,191 +149,6 @@ def scaled_div(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
return ScaledArray(lhs.data / rhs.data, lhs.scale / rhs.scale)


@core.register_scaled_lax_op
def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
# TODO: understand when promotion is really required?
A, B = as_scaled_array((A, B)) # type:ignore
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
# TODO: what happens to `sqrt` for non-floating scale?
output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale)
# Output dtype => promotion of A and B dtypes.
outdtype = jnp.promote_types(A.dtype, B.dtype)
Arescale = (A.scale / output_scale).astype(outdtype)
Brescale = (B.scale / output_scale).astype(outdtype)
# check correct type output if mismatch between data and scale precision
output_data = Arescale * A.data + Brescale * B.data
return ScaledArray(output_data, output_scale)


# TODO: understand difference between `add` and `add_anys`
register_scaled_op(add_any_p, scaled_add)
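# Hedged usage sketch of the L2 addition rule above (illustrative values; it is
# assumed here that ScaledArray(data, scale) accepts NumPy inputs directly):
_a = ScaledArray(np.array([1.0, -1.0], np.float32), np.float32(3.0))
_b = ScaledArray(np.array([1.0, 1.0], np.float32), np.float32(4.0))
_out = scaled_add(_a, _b)
# _out.scale == 5.0 (i.e. sqrt(3**2 + 4**2)) and _out.data ~= 0.6 * _a.data + 0.8 * _b.data,
# so _out.scale * _out.data still represents 3 * _a.data + 4 * _b.data == [7., 1.].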


@core.register_scaled_lax_op
def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
# TODO: understand when promotion is really required?
A, B = as_scaled_array((A, B)) # type:ignore
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
# TODO: what happens to `sqrt` for non-floating scale?
output_scale = lax.sqrt(A.scale * A.scale + B.scale * B.scale)
# Output dtype => promotion of A and B dtypes.
outdtype = jnp.promote_types(A.dtype, B.dtype)
Arescale = (A.scale / output_scale).astype(outdtype)
Brescale = (B.scale / output_scale).astype(outdtype)
# check correct type output if mismatch between data and scale precision
output_data = Arescale * A.data - Brescale * B.data
return ScaledArray(output_data, output_scale)


@core.register_scaled_lax_op
def scaled_dot_general(
lhs: ScaledArray,
rhs: ScaledArray,
dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], Tuple[Sequence[int], Sequence[int]]],
precision: Any = None,
preferred_element_type: Optional[DTypeLike] = None,
) -> ScaledArray:
# Checks on `dot_general` arguments. Only supporting a subset right now.
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) = dimension_numbers
assert len(lhs_batch_dims) == 0
assert len(rhs_batch_dims) == 0
assert len(lhs_contracting_dims) == 1
assert len(rhs_contracting_dims) == 1

contracting_dim_size = lhs.shape[lhs_contracting_dims[0]]
# "unit scaling" rule, based on the contracting axis.
outscale_dtype = jnp.promote_types(lhs.scale.dtype, rhs.scale.dtype)
contracting_rescale = np.sqrt(contracting_dim_size)
output_scale = lhs.scale * rhs.scale * contracting_rescale.astype(outscale_dtype)
# NOTE: need to be a bit careful about scale promotion?
output_data = lax.dot_general(
lhs.data,
rhs.data,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
output_data = output_data / contracting_rescale.astype(output_data.dtype)
return ScaledArray(output_data, output_scale)
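# Hedged worked example of the "unit scaling" rescale above (sizes and scales
# are assumed, not from this commit): contracting over K elements multiplies
# the scale by sqrt(K) and divides the data by sqrt(K), so scale * data is
# unchanged while unit-scaled, independent inputs keep roughly unit-scaled data.
_lhs_scale, _rhs_scale, _K = 2.0, 0.5, 256
_out_scale = _lhs_scale * _rhs_scale * np.sqrt(_K)  # == 16.0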


@core.register_scaled_lax_op
def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params) -> ScaledArray:
assert isinstance(lhs, ScaledArray)
assert isinstance(rhs, ScaledArray)
data = lax.conv_general_dilated_p.bind(lhs.data, rhs.data, **params)
# FIXME: should we change scaling if e.g. window > 3?
return ScaledArray(data, lhs.scale * rhs.scale)


@core.register_scaled_lax_op
def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
axes_rescale = np.sqrt(np.prod(axes_size))
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale.astype(val.data.dtype)
outscale = val.scale * axes_rescale.astype(val.scale.dtype)
return ScaledArray(data, outscale)
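# Hedged worked example of the rescale above (the reduced shape is assumed):
# summing N independent, roughly unit-scaled terms gives a standard deviation
# of about sqrt(N), so the data is divided by sqrt(N) and that factor is folded
# into the scale instead.
_axes_size = np.array([16, 64])               # e.g. reducing over a (16, 64) block
_axes_rescale = np.sqrt(np.prod(_axes_size))  # == 32.0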


@core.register_scaled_lax_op
def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)
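# Hedged worked example of the scale rule above (values are assumed): a product
# over N elements multiplies N copies of the scale together, hence scale**N.
_num_reduced = 3
_prod_scale = lax.integer_pow(np.float32(2.0), _num_reduced)  # == 8.0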


@core.register_scaled_lax_op
def scaled_reduce_max(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_max_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_min_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_window_sum(
val: ScaledArray,
window_dimensions: Any,
window_strides: Any,
padding: Any,
base_dilation: Any,
window_dilation: Any,
) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_window_sum_p.bind(
val.data,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
)
# FIXME: should we change scaling if e.g. window > 3?
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_window_min(
val: ScaledArray,
window_dimensions: Any,
window_strides: Any,
padding: Any,
base_dilation: Any,
window_dilation: Any,
) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_window_min_p.bind(
val.data,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_reduce_window_max(
val: ScaledArray,
window_dimensions: Any,
window_strides: Any,
padding: Any,
base_dilation: Any,
window_dilation: Any,
) -> ScaledArray:
assert isinstance(val, ScaledArray)
data = lax.reduce_window_max_p.bind(
val.data,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_is_finite(val: ScaledArray) -> Array:
assert isinstance(val, ScaledArray)
