Skip to content

Commit

Permalink
Add erfinv op (keras-team#19107)
Browse files Browse the repository at this point in the history
* erfinv backends implementations

* erfinv frontend op

* add test

* fix variable names

* naming convention

* fix symbolic call
  • Loading branch information
frazane authored Jan 26, 2024
1 parent a44c051 commit 47bf22d
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ def erf(x):
return jax.lax.erf(x)


def erfinv(x):
return jax.lax.erf_inv(x)


def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
Expand Down
4 changes: 4 additions & 0 deletions keras/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def erf(x):
return np.array(scipy.special.erf(x))


def erfinv(x):
return np.array(scipy.special.erfinv(x))


def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
Expand Down
4 changes: 4 additions & 0 deletions keras/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ def erf(x):
return tf.math.erf(x)


def erfinv(x):
return tf.math.erfinv(x)


def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
Expand Down
5 changes: 5 additions & 0 deletions keras/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,11 @@ def erf(x):
return torch.erf(x)


def erfinv(x):
x = convert_to_tensor(x)
return torch.erfinv(x)


def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
Expand Down
31 changes: 31 additions & 0 deletions keras/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,37 @@ def erf(x):
return backend.math.erf(x)


class Erfinv(Operation):
def compute_output_spec(self, x):
return KerasTensor(shape=x.shape, dtype=x.dtype)

def call(self, x):
return backend.math.erfinv(x)


@keras_export("keras.ops.erfinv")
def erfinv(x):
"""Computes the inverse error function of `x`, element-wise.
Args:
x: Input tensor.
Returns:
A tensor with the same dtype as `x`.
Example:
>>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])
>>> keras.ops.erfinv(x)
array([-0.47694, -0.17914, -0.08886, 0. , 0.27246], dtype=float32)
"""
if any_symbolic_tensors((x,)):
return Erfinv().symbolic_call(x)
x = backend.convert_to_tensor(x)
return backend.math.erfinv(x)



class Solve(Operation):
def call(self, a, b):
a = backend.convert_to_tensor(a)
Expand Down
32 changes: 32 additions & 0 deletions keras/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,38 @@ def test_erf_operation_edge_cases(self):
output_from_edge_erf_op = kmath.erf(edge_values)
self.assertAllClose(expected_output, output_from_edge_erf_op, atol=1e-4)


def test_erfinv_operation_basic(self):
# Sample values for testing
sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])

# Expected output using numpy's approximation of the error function
expected_output = scipy.special.erfinv(sample_values)

# Output from the erf operation in keras_core
output_from_erfinv_op = kmath.erfinv(sample_values)

# Assert that the outputs are close
self.assertAllClose(expected_output, output_from_erfinv_op, atol=1e-4)

def test_erfinv_operation_dtype(self):
# Test for float32 and float64 data types
for dtype in ("float32", "float64"):
sample_values = np.array(
[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype
)
expected_output = scipy.special.erfinv(sample_values)
output_from_erfinv_op = kmath.erfinv(sample_values)
self.assertAllClose(expected_output, output_from_erfinv_op, atol=1e-4)

def test_erfinv_operation_edge_cases(self):
# Test for edge cases
edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64)
expected_output = scipy.special.erfinv(edge_values)
output_from_edge_erfinv_op = kmath.erfinv(edge_values)
self.assertAllClose(expected_output, output_from_edge_erfinv_op, atol=1e-4)


def test_solve(self):
x1 = np.array([[1, 2], [4, 5]], dtype="float32")
x2 = np.array([[2, 4], [8, 10]], dtype="float32")
Expand Down

0 comments on commit 47bf22d

Please sign in to comment.