diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py
index c45a33ed8..ca7705078 100644
--- a/keras_core/backend/jax/numpy.py
+++ b/keras_core/backend/jax/numpy.py
@@ -207,6 +207,16 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
     )
 
 
+def digitize(x, bins):
+    """Return the index of the bin each value of `x` falls into.
+
+    Both arguments are converted with `convert_to_tensor`, and the result
+    of `jnp.digitize` is cast to "int64".
+    """
+    bin_indices = jnp.digitize(convert_to_tensor(x), convert_to_tensor(bins))
+    return cast(bin_indices, "int64")
+
+
 def dot(x, y):
     return jnp.dot(x, y)
 
diff --git a/keras_core/backend/numpy/numpy.py b/keras_core/backend/numpy/numpy.py
index 23d8068ca..ad41a81c1 100644
--- a/keras_core/backend/numpy/numpy.py
+++ b/keras_core/backend/numpy/numpy.py
@@ -199,6 +199,12 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
     )
 
 
+def digitize(x, bins):
+    """Return the index of the bin each value of `x` falls into."""
+    bin_indices = np.digitize(x, bins)
+    return bin_indices
+
+
 def dot(x, y):
     return np.dot(x, y)
 
diff --git a/keras_core/backend/tensorflow/numpy.py b/keras_core/backend/tensorflow/numpy.py
index 954e1e08a..23402887e 100644
--- a/keras_core/backend/tensorflow/numpy.py
+++ b/keras_core/backend/tensorflow/numpy.py
@@ -210,6 +210,19 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
     )
 
 
+def digitize(x, bins):
+    """Return the index of the bin each value of `x` falls into.
+
+    `x` is converted with `convert_to_tensor`, while `bins` is turned into
+    a Python list and passed as the `boundaries` argument of
+    `tf.raw_ops.Bucketize`. The result is cast to "int64".
+    """
+    inputs = convert_to_tensor(x)
+    boundaries = list(bins)
+    bucket_ids = tf.raw_ops.Bucketize(input=inputs, boundaries=boundaries)
+    return tf.cast(bucket_ids, "int64")
+
+
 def dot(x, y):
     return tfnp.dot(x, y)
 
diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py
index 0b3606d7a..e1896cab8 100644
--- a/keras_core/backend/torch/numpy.py
+++ b/keras_core/backend/torch/numpy.py
@@ -327,6 +327,21 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
     )
 
 
+def digitize(x, bins):
+    """Return the index of the bin each value of `x` falls into.
+
+    Both arguments are converted with `convert_to_tensor`, and the result
+    of `torch.bucketize` is cast to "int32".
+    """
+    values = convert_to_tensor(x)
+    boundaries = convert_to_tensor(bins)
+    # NOTE(review): per the torch docs, `right=True` selects the
+    # `boundaries[i-1] <= v < boundaries[i]` convention; confirm it is meant
+    # to mirror the default behavior of `np.digitize`.
+    bucket_ids = torch.bucketize(values, boundaries, right=True)
+    return cast(bucket_ids, "int32")
+
+
 def dot(x, y):
     x, y = convert_to_tensor(x), convert_to_tensor(y)
     if x.ndim == 0 or y.ndim == 0:
diff --git a/keras_core/ops/numpy.py b/keras_core/ops/numpy.py
index 820357f49..99925f10f 100644
--- a/keras_core/ops/numpy.py
+++ b/keras_core/ops/numpy.py
@@ -34,6 +34,7 @@
 diag
 diagonal
 diff
+digitize
 divide
 dot
 dtype
@@ -1534,6
+1535,55 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
     )
 
 
+class Digitize(Operation):
+    """Operation that maps each value of `x` to the index of its bin."""
+
+    def call(self, x, bins):
+        return backend.numpy.digitize(x, bins)
+
+    def compute_output_spec(self, x, bins):
+        # `bins` may not expose a static shape (e.g. a plain Python list
+        # passed alongside a symbolic `x`); only validate when it does.
+        bins_shape = getattr(bins, "shape", None)
+        if bins_shape is not None and len(bins_shape) != 1:
+            raise ValueError(
+                "`bins` must be an array of one dimension, but received `bins` "
+                f"of shape {bins_shape}."
+            )
+        # The output mirrors the shape of `x` and holds integer bin indices.
+        return KerasTensor(x.shape, dtype="int")
+
+
+@keras_core_export(["keras_core.ops.digitize", "keras_core.ops.numpy.digitize"])
+def digitize(x, bins):
+    """Returns the indices of the bins to which each value in `x` belongs.
+
+    The result follows the convention of `np.digitize(x, bins)` with its
+    default arguments: index `i` means `bins[i - 1] <= x < bins[i]`.
+
+    Args:
+        x: Input array to be binned.
+        bins: Array of bins. It has to be one-dimensional and monotonically
+            increasing.
+
+    Returns:
+        Output array of integer indices, of same shape as `x`.
+
+    Raises:
+        ValueError: If an input is symbolic and `bins` has a static shape
+            that is not one-dimensional.
+
+    Examples:
+    >>> x = np.array([0.0, 1.0, 3.0, 1.6])
+    >>> bins = np.array([0.0, 3.0, 4.5, 7.0])
+    >>> keras_core.ops.digitize(x, bins)
+    array([1, 1, 2, 1])
+    """
+    if any_symbolic_tensors((x, bins)):
+        return Digitize().symbolic_call(x, bins)
+    return backend.numpy.digitize(x, bins)
+
+
 class Dot(Operation):
     def call(self, x1, x2):
         return backend.numpy.dot(x1, x2)
diff --git a/keras_core/ops/numpy_test.py b/keras_core/ops/numpy_test.py
index df97cdd3f..4ea9dba37 100644
--- a/keras_core/ops/numpy_test.py
+++ b/keras_core/ops/numpy_test.py
@@ -3,6 +3,7 @@
 
 from keras_core import backend
 from keras_core import testing
+from keras_core.backend.common import standardize_dtype
 from keras_core.backend.common.keras_tensor import KerasTensor
 from keras_core.ops import numpy as knp
 
@@ -660,6 +661,17 @@ def test_xor(self):
         y = KerasTensor([2, 3, 4])
         knp.logical_xor(x, y)
 
+    def test_digitize(self):
+        x = KerasTensor((2, 3))
+        bins = KerasTensor((3,))
+        self.assertEqual(knp.digitize(x, bins).shape, (2, 3))
+        self.assertEqual(knp.digitize(x, bins).dtype, standardize_dtype("int"))
+
+        x = KerasTensor([2, 3])
+        bins = KerasTensor([2, 3, 4])
+        with self.assertRaises(ValueError):
+            knp.digitize(x, bins)
+
 
 class
NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
     def test_mean(self):
@@ -2092,6 +2104,28 @@ def test_where(self):
         self.assertAllClose(knp.where(x > 1, x, y), np.where(x > 1, x, y))
         self.assertAllClose(knp.Where()(x > 1, x, y), np.where(x > 1, x, y))
 
+    def test_digitize(self):
+        # Each case pairs sample values with monotonically increasing bin
+        # edges; `np.digitize` provides the reference result.
+        cases = [
+            ([0.0, 1.0, 3.0, 1.6], [0.0, 3.0, 4.5, 7.0]),
+            ([0.2, 6.4, 3.0, 1.6], [0.0, 1.0, 2.5, 4.0, 10.0]),
+            ([1, 4, 10, 15], [4, 10, 14, 15]),
+        ]
+        int_dtype = standardize_dtype("int")
+        for x_values, bin_edges in cases:
+            x = np.array(x_values)
+            bins = np.array(bin_edges)
+            expected = np.digitize(x, bins)
+            self.assertAllClose(knp.digitize(x, bins), expected)
+            self.assertAllClose(knp.Digitize()(x, bins), expected)
+            self.assertEqual(
+                standardize_dtype(knp.digitize(x, bins).dtype), int_dtype
+            )
+            self.assertEqual(
+                standardize_dtype(knp.Digitize()(x, bins).dtype), int_dtype
+            )
+
 
 class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
     def test_mean(self):