Skip to content

Commit

Permalink
Implement basic boolean scaled operators: eq, ne, ...
Browse files Browse the repository at this point in the history
Note: at the moment, the binary boolean scaled translation is not
optimized, using a casting to `float32` by default!
  • Loading branch information
balancap committed Nov 21, 2023
1 parent 0fc3400 commit 511d8b5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
50 changes: 50 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
Expand Down Expand Up @@ -145,3 +146,52 @@ def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
data = lax.reduce_min_p.bind(val.data, axes=axes)
# unchanged scaling.
return ScaledArray(data, val.scale)


@core.register_scaled_lax_op
def scaled_is_finite(val: ScaledArray) -> jax.Array:
assert isinstance(val, ScaledArray)
if np.issubdtype(val.scale.dtype, np.integer):
# Integer scale case => only check the data component.
return lax.is_finite(val.data)
# Both data & scale need to be finite!
return lax.and_p.bind(lax.is_finite(val.data), lax.is_finite(val.scale))


def scaled_boolean_binary_op(lhs: ScaledArray, rhs: ScaledArray, prim: jax.core.Primitive) -> jax.Array:
"""Generic implementation of any boolean binary operation."""
assert isinstance(lhs, ScaledArray)
assert isinstance(rhs, ScaledArray)
# FIXME: fix this absolute horror!
# TODO: use max scale + special case for scalars.
return prim.bind(lhs.to_array(dtype=np.float32), rhs.to_array(dtype=np.float32))


@core.register_scaled_lax_op
def scaled_eq(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.eq_p)


@core.register_scaled_lax_op
def scaled_ne(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.ne_p)


@core.register_scaled_lax_op
def scaled_gt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.gt_p)


@core.register_scaled_lax_op
def scaled_ge(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.ge_p)


@core.register_scaled_lax_op
def scaled_lt(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.lt_p)


@core.register_scaled_lax_op
def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> jax.Array:
return scaled_boolean_binary_op(lhs, rhs, lax.le_p)
22 changes: 22 additions & 0 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
Expand All @@ -12,6 +13,7 @@
scaled_concatenate,
scaled_convert_element_type,
scaled_dot_general,
scaled_is_finite,
scaled_mul,
scaled_slice,
scaled_sub,
Expand Down Expand Up @@ -120,3 +122,23 @@ def test__scaled_reduce__single_axis__proper_scaling(self, reduce_prim, expected
assert out.dtype == val.dtype
npt.assert_almost_equal(out.scale, expected_scale)
npt.assert_array_almost_equal(out, reduce_prim.bind(np.asarray(val), axes=axes))


class ScaledTranslationBooleanPrimitivesTests(chex.TestCase):
def setUp(self):
super().setUp()
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

@parameterized.parameters(
{"val": scaled_array([2, 3], 2.0, dtype=np.float32), "expected_out": [True, True]},
# Supporting `int` scale as well.
{"val": scaled_array([2, np.inf], 2, dtype=np.float32), "expected_out": [True, False]},
{"val": scaled_array([2, 3], np.nan, dtype=np.float32), "expected_out": [False, False]},
{"val": scaled_array([np.nan, 3], 3.0, dtype=np.float32), "expected_out": [False, True]},
)
def test__scaled_is_finite__proper_result(self, val, expected_out):
out = scaled_is_finite(val)
assert isinstance(out, jax.Array)
assert out.dtype == np.bool_
npt.assert_array_equal(out, expected_out)

0 comments on commit 511d8b5

Please sign in to comment.