From ff60e343270e0f135d93748f255d76c6f3f1ea9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Wed, 20 Sep 2023 15:18:47 -0700 Subject: [PATCH] Some dtype fixes (#935) * Some dtype fixes * Nits --- keras_core/backend/jax/numpy.py | 7 +++++ keras_core/backend/numpy/numpy.py | 18 ++++++++++++- keras_core/backend/tensorflow/numpy.py | 11 ++++++++ keras_core/backend/torch/numpy.py | 10 ++++++- keras_core/ops/numpy_test.py | 36 +++++++++++++++++++++++--- 5 files changed, 77 insertions(+), 5 deletions(-) diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py index bd1a11c3d..6471075e0 100644 --- a/keras_core/backend/jax/numpy.py +++ b/keras_core/backend/jax/numpy.py @@ -113,6 +113,13 @@ def append( def arange(start, stop=None, step=1, dtype=None): + if dtype is None: + if hasattr(start, "dtype"): + dtype = start.dtype + elif isinstance(start, int): + dtype = "int32" + else: + dtype = config.floatx() return jnp.arange(start, stop, step=step, dtype=dtype) diff --git a/keras_core/backend/numpy/numpy.py b/keras_core/backend/numpy/numpy.py index 33f46d364..d2ee78aba 100644 --- a/keras_core/backend/numpy/numpy.py +++ b/keras_core/backend/numpy/numpy.py @@ -1,5 +1,8 @@ import numpy as np +from keras_core.backend import config +from keras_core.backend import standardize_dtype + def add(x1, x2): return np.add(x1, x2) @@ -77,6 +80,13 @@ def append( def arange(start, stop=None, step=None, dtype=None): + if dtype is None: + if hasattr(start, "dtype"): + dtype = start.dtype + elif isinstance(start, int): + dtype = "int32" + else: + dtype = config.floatx() return np.arange(start, stop, step=step, dtype=dtype) @@ -124,6 +134,7 @@ def argsort(x, axis=-1): def array(x, dtype=None): + dtype = dtype or config.floatx() return np.array(x, dtype=dtype) @@ -271,6 +282,7 @@ def floor(x): def full(shape, fill_value, dtype=None): + dtype = dtype or config.floatx() return np.full(shape, fill_value, dtype=dtype) @@ -592,7 +604,11 @@ def square(x): def sqrt(x): - return np.sqrt(x) + dtype = None + if hasattr(x, "dtype"): + if standardize_dtype(x.dtype).startswith("int"): + dtype = config.floatx() + return np.sqrt(x, dtype=dtype) def squeeze(x, axis=None): diff --git a/keras_core/backend/tensorflow/numpy.py b/keras_core/backend/tensorflow/numpy.py index 08ae6ca9d..eef26c5a0 100644 --- a/keras_core/backend/tensorflow/numpy.py +++ b/keras_core/backend/tensorflow/numpy.py @@ -5,6 +5,7 @@ import tensorflow as tf from tensorflow.experimental import numpy as tfnp +from keras_core.backend import config from keras_core.backend.tensorflow.core import convert_to_tensor @@ -176,6 +177,13 @@ def append( def arange(start, stop=None, step=1, dtype=None): # tfnp.arange has trouble with dynamic Tensors in compiled function. # tf.range does not. + if dtype is None: + if hasattr(start, "dtype"): + dtype = start.dtype + elif isinstance(start, int): + dtype = "int32" + else: + dtype = config.floatx() return tf.range(start, stop, delta=step, dtype=dtype) @@ -749,6 +757,9 @@ def square(x): def sqrt(x): + x = convert_to_tensor(x) + if tf.as_dtype(x.dtype).is_integer: + x = tf.cast(x, dtype=config.floatx()) return tfnp.sqrt(x) diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index da444be70..04dd695aa 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -1,6 +1,7 @@ import numpy as np import torch +from keras_core.backend import config from keras_core.backend.torch.core import cast from keras_core.backend.torch.core import convert_to_tensor from keras_core.backend.torch.core import get_device @@ -91,7 +92,7 @@ def zeros(shape, dtype="float32"): def zeros_like(x, dtype=None): x = convert_to_tensor(x) - dtype = to_torch_dtype(dtype) + dtype = to_torch_dtype(dtype or x.dtype) return torch.zeros_like(x, dtype=dtype) @@ -160,6 +161,13 @@ def append( def arange(start, stop=None, step=1, dtype=None): + if dtype is None: + if hasattr(start, "dtype"): + dtype = start.dtype + elif isinstance(start, int): + dtype = "int32" + else: + dtype = config.floatx() dtype = to_torch_dtype(dtype) if stop is None: return torch.arange(end=start, dtype=dtype, device=get_device()) diff --git a/keras_core/ops/numpy_test.py b/keras_core/ops/numpy_test.py index 83e333563..c8d61767c 100644 --- a/keras_core/ops/numpy_test.py +++ b/keras_core/ops/numpy_test.py @@ -3571,9 +3571,37 @@ def test_split(self): self.assertEqual(len(knp.Split(2)(x)), 2) def test_sqrt(self): - x = np.array([[1, 4, 9], [16, 25, 36]]) - self.assertAllClose(knp.sqrt(x), np.sqrt(x)) - self.assertAllClose(knp.Sqrt()(x), np.sqrt(x)) + x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") + ref_y = np.sqrt(x) + y = knp.sqrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + y = knp.Sqrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + + @pytest.mark.skipif( + backend.backend() == "jax", reason="JAX does not support float64." + ) + def test_sqrt_float64(self): + x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float64") + ref_y = np.sqrt(x) + y = knp.sqrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float64") + self.assertAllClose(y, ref_y) + y = knp.Sqrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float64") + self.assertAllClose(y, ref_y) + + def test_sqrt_int32(self): + x = np.array([[1, 4, 9], [16, 25, 36]], dtype="int32") + ref_y = np.sqrt(x) + y = knp.sqrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + y = knp.Sqrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) def test_stack(self): x = np.array([[1, 2, 3], [3, 2, 1]]) @@ -3704,6 +3732,8 @@ def test_arange(self): self.assertAllClose(knp.Arange()(3, 7), np.arange(3, 7)) self.assertAllClose(knp.Arange()(3, 7, 2), np.arange(3, 7, 2)) + self.assertEqual(standardize_dtype(knp.arange(3).dtype), "int32") + def test_full(self): self.assertAllClose(knp.full([2, 3], 0), np.full([2, 3], 0)) self.assertAllClose(knp.full([2, 3], 0.1), np.full([2, 3], 0.1))