Implement basic scaled operations for MLP model. (#16)
* Implement basic scaled operations for MLP model.

Adding scaled translations/implementations for `add`, `sub` and `dot_general`.
Should allow training a minimal MLP model.

* Extend JAX scaled dot to dot_general.

---------

Co-authored-by: samho <>
balancap authored Nov 15, 2023
1 parent bba779d commit 6373362
Showing 2 changed files with 88 additions and 1 deletion.
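For context, here is a minimal sketch of the kind of computation these ops enable: a single linear layer evaluated entirely on `ScaledArray` values. The shapes and scale values are illustrative assumptions, and the bias is pre-broadcast to the output shape so the example only uses the ops added in this commit; the `scaled_array`, `scaled_dot_general` and `scaled_add` calls mirror the tests below.

```python
import numpy as np

from jax_scaled_arithmetics.core import scaled_array
from jax_scaled_arithmetics.lax import scaled_add, scaled_dot_general

# Linear layer y = x @ w + b, with every tensor carried as a ScaledArray.
x = scaled_array(np.random.rand(4, 8), 2.0, dtype=np.float32)   # activations
w = scaled_array(np.random.rand(8, 16), 0.5, dtype=np.float32)  # weights
b = scaled_array(np.random.rand(4, 16), 1.0, dtype=np.float32)  # bias, pre-broadcast to (4, 16)

# Contract axis 1 of `x` with axis 0 of `w`, no batch dimensions: the same
# `dimension_numbers` layout as in the tests below.
h = scaled_dot_general(x, w, (((1,), (0,)), ((), ())))
y = scaled_add(h, b)
print(y.scale)  # scale propagated through the dot and the add
```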
59 changes: 58 additions & 1 deletion jax_scaled_arithmetics/lax/scaled_ops.py
@@ -1,7 +1,8 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Optional, Sequence, Tuple

import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.typing import DTypeLike  # assumed import: `DTypeLike` is used below but not imported in the visible hunk

from jax_scaled_arithmetics import core
@@ -50,3 +51,59 @@ def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)


@core.register_scaled_lax_op
def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
# Only supporting floating scale right now.
assert A.scale.dtype == B.scale.dtype
assert np.issubdtype(A.scale.dtype, np.floating)
# TODO: what happens to `sqrt` for non-floating scale?
output_scale = lax.sqrt(A.scale**2 + B.scale**2)
# TODO: check the output dtype when data and scale differ in precision.
output_data = (A.scale / output_scale) * A.data + (B.scale / output_scale) * B.data
return ScaledArray(output_data, output_scale)


@core.register_scaled_lax_op
def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
# Only supporting floating scale right now.
assert A.scale.dtype == B.scale.dtype
assert np.issubdtype(A.scale.dtype, np.floating)
# TODO: what happens to `sqrt` for non-floating scale?
output_scale = lax.sqrt(A.scale**2 + B.scale**2)
# TODO: check the output dtype when data and scale differ in precision.
output_data = (A.scale / output_scale) * A.data - (B.scale / output_scale) * B.data
return ScaledArray(output_data, output_scale)


@core.register_scaled_lax_op
def scaled_dot_general(
lhs: ScaledArray,
rhs: ScaledArray,
dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], Tuple[Sequence[int], Sequence[int]]],
precision: Any = None,
preferred_element_type: Optional[DTypeLike] = None,
) -> ScaledArray:
# Checks on `dot_general` arguments. Only supporting a subset right now.
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) = dimension_numbers
assert len(lhs_batch_dims) == 0
assert len(rhs_batch_dims) == 0
assert len(lhs_contracting_dims) == 1
assert len(rhs_contracting_dims) == 1

contracting_dim_size = lhs.shape[lhs_contracting_dims[0]]
# "unit scaling" rule, based on the contracting axis.
contracting_rescale = np.sqrt(contracting_dim_size).astype(lhs.dtype)
output_scale = lhs.scale * rhs.scale * contracting_rescale
output_data = (
lax.dot_general(
lhs.data,
rhs.data,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
/ contracting_rescale
)
return ScaledArray(output_data, output_scale)
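Both scale-propagation rules above follow the same unit-scaling heuristic: keep the `data` component near unit variance and carry magnitude in `scale`. For independent unit-variance inputs, a sum has variance `scale_A**2 + scale_B**2`, hence the root-sum-square output scale in `scaled_add`/`scaled_sub`; a contraction over `K` unit-variance products has variance `K`, hence the `sqrt(K)` factor moved from the data into the scale in `scaled_dot_general`. A quick NumPy sketch of both facts (illustrative only, not part of the commit):

```python
import numpy as np

rng = np.random.default_rng(0)
s_a, s_b, k, n = 3.0, 2.0, 5, 100_000

# Add/sub rule: blending by s/s_out keeps the data ~unit variance,
# with s_out the root-sum-square of the input scales.
s_out = np.sqrt(s_a**2 + s_b**2)
a, b = rng.standard_normal(n), rng.standard_normal(n)
out = (s_a / s_out) * a + (s_b / s_out) * b
print(out.std())  # ~1.0

# Dot rule: summing K unit-variance products grows the standard deviation
# by sqrt(K); dividing it out of the data and folding it into the scale
# restores unit variance.
lhs = rng.standard_normal((n, k))
rhs = rng.standard_normal(k)
dot = (lhs @ rhs) / np.sqrt(k)
print(dot.std())  # ~1.0
```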
30 changes: 30 additions & 0 deletions tests/lax/test_scaled_ops.py
@@ -5,11 +5,14 @@

from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.lax import (
scaled_add,
scaled_broadcast_in_dim,
scaled_concatenate,
scaled_convert_element_type,
scaled_dot_general,
scaled_mul,
scaled_slice,
scaled_sub,
scaled_transpose,
)

@@ -58,3 +61,30 @@ def test__scaled_mul__proper_scaling(self):
assert isinstance(z, ScaledArray)
assert z.scale == 6
npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y))

def test__scaled_add__proper_scaling(self):
x = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
y = scaled_array([1.5, 4.5], 2.0, dtype=np.float32)
z = scaled_add(x, y)
assert isinstance(z, ScaledArray)
assert z.dtype == x.dtype
npt.assert_almost_equal(z.scale, np.sqrt(4.0 + 9.0))
npt.assert_array_almost_equal(z, np.asarray(x) + np.asarray(y))

def test__scaled_sub__proper_scaling(self):
x = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32)
y = scaled_array([1.5, 4.5], 2.0, dtype=np.float32)
z = scaled_sub(x, y)
assert isinstance(z, ScaledArray)
assert z.dtype == x.dtype
npt.assert_almost_equal(z.scale, np.sqrt(4.0 + 9.0))
npt.assert_array_almost_equal(z, np.asarray(x) - np.asarray(y))

def test__scaled_dot_general__proper_scaling(self):
lhs = scaled_array(np.random.rand(3, 5), 2.0, dtype=np.float32)
rhs = scaled_array(np.random.rand(5, 2), 3.0, dtype=np.float32)
out = scaled_dot_general(lhs, rhs, (((1,), (0,)), ((), ())))
assert isinstance(out, ScaledArray)
assert out.dtype == lhs.dtype
npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * np.sqrt(5))
npt.assert_array_almost_equal(out, np.asarray(lhs) @ np.asarray(rhs))
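As a quick arithmetic check of the scale values asserted above (illustrative, not part of the test suite):

```python
import numpy as np

# scaled_add / scaled_sub with input scales 3.0 and 2.0:
print(np.sqrt(3.0**2 + 2.0**2))  # sqrt(13) ~= 3.606
# scaled_dot_general with scales 2.0 and 3.0 and contracting size 5:
print(2.0 * 3.0 * np.sqrt(5.0))  # ~= 13.416
```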
