From 47bf22d61fcf70255012802bb0772dbfdff002c4 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta <62377868+frazane@users.noreply.github.com> Date: Sat, 27 Jan 2024 00:01:16 +0100 Subject: [PATCH] Add `erfinv` op (#19107) * erfinv backends implementations * erfinv frontend op * add test * fix variable names * naming convention * fix symbolic call --- keras/backend/jax/math.py | 4 ++++ keras/backend/numpy/math.py | 4 ++++ keras/backend/tensorflow/math.py | 4 ++++ keras/backend/torch/math.py | 5 +++++ keras/ops/math.py | 31 +++++++++++++++++++++++++++++++ keras/ops/math_test.py | 32 ++++++++++++++++++++++++++++++++ 6 files changed, 80 insertions(+) diff --git a/keras/backend/jax/math.py b/keras/backend/jax/math.py index acaa0a4a673..ec6d1a962a5 100644 --- a/keras/backend/jax/math.py +++ b/keras/backend/jax/math.py @@ -254,6 +254,10 @@ def erf(x): return jax.lax.erf(x) +def erfinv(x): + return jax.lax.erf_inv(x) + + def solve(a, b): a = convert_to_tensor(a) b = convert_to_tensor(b) diff --git a/keras/backend/numpy/math.py b/keras/backend/numpy/math.py index fa4abb70d25..0081cdcf32a 100644 --- a/keras/backend/numpy/math.py +++ b/keras/backend/numpy/math.py @@ -308,6 +308,10 @@ def erf(x): return np.array(scipy.special.erf(x)) +def erfinv(x): + return np.array(scipy.special.erfinv(x)) + + def solve(a, b): a = convert_to_tensor(a) b = convert_to_tensor(b) diff --git a/keras/backend/tensorflow/math.py b/keras/backend/tensorflow/math.py index c1806a2fd19..053c797a203 100644 --- a/keras/backend/tensorflow/math.py +++ b/keras/backend/tensorflow/math.py @@ -246,6 +246,10 @@ def erf(x): return tf.math.erf(x) +def erfinv(x): + return tf.math.erfinv(x) + + def solve(a, b): a = convert_to_tensor(a) b = convert_to_tensor(b) diff --git a/keras/backend/torch/math.py b/keras/backend/torch/math.py index d3d4ceb6eae..a949f1fb03b 100644 --- a/keras/backend/torch/math.py +++ b/keras/backend/torch/math.py @@ -415,6 +415,11 @@ def erf(x): return torch.erf(x) +def erfinv(x): + x = convert_to_tensor(x) + return torch.erfinv(x) + + def solve(a, b): a = convert_to_tensor(a) b = convert_to_tensor(b) diff --git a/keras/ops/math.py b/keras/ops/math.py index bff67333bac..d6752bf3966 100644 --- a/keras/ops/math.py +++ b/keras/ops/math.py @@ -968,6 +968,37 @@ def erf(x): return backend.math.erf(x) +class Erfinv(Operation): + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x): + return backend.math.erfinv(x) + + +@keras_export("keras.ops.erfinv") +def erfinv(x): + """Computes the inverse error function of `x`, element-wise. + + Args: + x: Input tensor. + + Returns: + A tensor with the same dtype as `x`. + + Example: + + >>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3]) + >>> keras.ops.erfinv(x) + array([-0.47694, -0.17914, -0.08886, 0. , 0.27246], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Erfinv().symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.math.erfinv(x) + + + class Solve(Operation): def call(self, a, b): a = backend.convert_to_tensor(a) diff --git a/keras/ops/math_test.py b/keras/ops/math_test.py index 203e26c5c34..1173db3c988 100644 --- a/keras/ops/math_test.py +++ b/keras/ops/math_test.py @@ -887,6 +887,38 @@ def test_erf_operation_edge_cases(self): output_from_edge_erf_op = kmath.erf(edge_values) self.assertAllClose(expected_output, output_from_edge_erf_op, atol=1e-4) + + def test_erfinv_operation_basic(self): + # Sample values for testing + sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + + # Expected output using numpy's approximation of the error function + expected_output = scipy.special.erfinv(sample_values) + + # Output from the erf operation in keras_core + output_from_erfinv_op = kmath.erfinv(sample_values) + + # Assert that the outputs are close + self.assertAllClose(expected_output, output_from_erfinv_op, atol=1e-4) + + def test_erfinv_operation_dtype(self): + # Test for float32 and float64 data types + for dtype in ("float32", "float64"): + sample_values = np.array( + [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype + ) + expected_output = scipy.special.erfinv(sample_values) + output_from_erfinv_op = kmath.erfinv(sample_values) + self.assertAllClose(expected_output, output_from_erfinv_op, atol=1e-4) + + def test_erfinv_operation_edge_cases(self): + # Test for edge cases + edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64) + expected_output = scipy.special.erfinv(edge_values) + output_from_edge_erfinv_op = kmath.erfinv(edge_values) + self.assertAllClose(expected_output, output_from_edge_erfinv_op, atol=1e-4) + + def test_solve(self): x1 = np.array([[1, 2], [4, 5]], dtype="float32") x2 = np.array([[2, 4], [8, 10]], dtype="float32")