Skip to content

Commit

Permalink
Add ops.math.norm (keras-team#19099)
Browse files Browse the repository at this point in the history
* Add `ops.math.norm`

* Improve test coverage
  • Loading branch information
james77777778 authored Jan 25, 2024
1 parent adacf2c commit 06cde60
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,7 @@ def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
return jnp.linalg.solve(a, b)


def norm(x, ord=None, axis=None, keepdims=False):
return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
4 changes: 4 additions & 0 deletions keras/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,7 @@ def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
return np.linalg.solve(a, b)


def norm(x, ord=None, axis=None, keepdims=False):
return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
124 changes: 124 additions & 0 deletions keras/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp

from keras.backend import standardize_dtype
from keras.backend.tensorflow.core import convert_to_tensor
Expand Down Expand Up @@ -249,3 +250,126 @@ def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
return tf.linalg.solve(a, b)


def norm(x, ord=None, axis=None, keepdims=False):
x = convert_to_tensor(x)
x_shape = x.shape
ndim = x_shape.rank

if ord is None:
ord = "euclidean"

if axis is None:
axis = tuple(range(ndim))
elif isinstance(axis, int):
axis = (axis,)

axis = axis[0] if len(axis) == 1 else axis
num_axes = 1 if isinstance(axis, int) else len(axis)

# Fast path to utilze `tf.linalg.norm`
if (num_axes == 1 and ord in ("euclidean", 2)) or (
num_axes == 2 and ord in ("euclidean", "fro")
):
return tf.linalg.norm(x, axis=axis, keepdims=keepdims)

# Ref: jax.numpy.linalg.norm
if num_axes == 1 and ord not in ("euclidean", 2, "fro", "nuc"):
if ord == float("-inf"):
return tf.math.reduce_min(
tf.math.abs(x), axis=axis, keepdims=keepdims
)
elif ord == 0:
return tf.math.reduce_sum(
tf.cast(tf.not_equal(x, 0), dtype=x.dtype),
axis=axis,
keepdims=keepdims,
)
elif ord == 1:
return tf.math.reduce_sum(
tf.math.abs(x), axis=axis, keepdims=keepdims
)
elif ord == float("inf"):
return tf.math.reduce_max(
tf.math.abs(x), axis=axis, keepdims=keepdims
)
else:
ord = convert_to_tensor(ord, dtype=x.dtype)
out = tf.math.reduce_sum(
tf.pow(tf.math.abs(x), ord), axis=axis, keepdims=keepdims
)
return tf.pow(out, 1.0 / ord)
elif num_axes == 2 and ord in (
"nuc",
float("-inf"),
-2,
-1,
1,
2,
float("inf"),
):
row_axis, col_axis = axis[0], axis[1]
row_axis = row_axis + ndim if row_axis < 0 else row_axis
col_axis = col_axis + ndim if col_axis < 0 else col_axis
if ord == float("-inf"):
if not keepdims and row_axis > col_axis:
row_axis -= 1
x = tf.math.reduce_min(
tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdims=keepdims,
)
elif ord == -1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
x = tf.math.reduce_min(
tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdims=keepdims,
)
elif ord == 1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
x = tf.math.reduce_max(
tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdims=keepdims,
)
elif ord == float("inf"):
if not keepdims and row_axis > col_axis:
row_axis -= 1
x = tf.math.reduce_max(
tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdims=keepdims,
)
elif ord in ("nuc", 2, -2):
x = tfnp.moveaxis(x, axis, (-2, -1))
if ord == 2:
x = tf.math.reduce_max(
tf.linalg.svd(x, compute_uv=False), axis=-1
)
elif ord == -2:
x = tf.math.reduce_min(
tf.linalg.svd(x, compute_uv=False), axis=-1
)
else:
x = tf.math.reduce_sum(
tf.linalg.svd(x, compute_uv=False), axis=-1
)
if keepdims:
x = tf.expand_dims(x, axis[0])
x = tf.expand_dims(x, axis[1])
return x

if num_axes == 1:
raise ValueError(
f"Invalid `ord` argument for vector norm. Received: ord={ord}"
)
elif num_axes == 2:
raise ValueError(
f"Invalid `ord` argument for matrix norm. Received: ord={ord}"
)
else:
raise ValueError(f"Invalid axis values. Received: axis={axis}")
5 changes: 5 additions & 0 deletions keras/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,8 @@ def solve(a, b):
a = convert_to_tensor(a)
b = convert_to_tensor(b)
return torch.linalg.solve(a, b)


def norm(x, ord=None, axis=None, keepdims=False):
x = convert_to_tensor(x)
return torch.linalg.norm(x, ord=ord, dim=axis, keepdim=keepdims)
119 changes: 119 additions & 0 deletions keras/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,3 +1001,122 @@ def solve(a, b):
a = backend.convert_to_tensor(a)
b = backend.convert_to_tensor(b)
return backend.math.solve(a, b)


class Norm(Operation):
def __init__(self, ord=None, axis=None, keepdims=False):
super().__init__()
if isinstance(ord, str):
if ord not in ("fro", "nuc"):
raise ValueError(
"Invalid `ord` argument. "
"Expected one of {'fro', 'nuc'} when using string. "
f"Received: ord={ord}"
)
if isinstance(axis, int):
axis = [axis]
self.ord = ord
self.axis = axis
self.keepdims = keepdims

def compute_output_spec(self, x):
output_dtype = backend.standardize_dtype(x.dtype)
if "int" in output_dtype or output_dtype == "bool":
output_dtype = backend.floatx()
if self.axis is None:
axis = tuple(range(len(x.shape)))
else:
axis = self.axis
num_axes = len(axis)
if num_axes == 1 and isinstance(self.ord, str):
raise ValueError(
"Invalid `ord` argument for vector norm. "
f"Received: ord={self.ord}"
)
elif num_axes == 2 and self.ord not in (
None,
"fro",
"nuc",
float("inf"),
float("-inf"),
1,
-1,
2,
-2,
):
raise ValueError(
"Invalid `ord` argument for matrix norm. "
f"Received: ord={self.ord}"
)
return KerasTensor(
reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
dtype=output_dtype,
)

def call(self, x):
x = backend.convert_to_tensor(x)
return backend.math.norm(
x, ord=self.ord, axis=self.axis, keepdims=self.keepdims
)


@keras_export("keras.ops.norm")
def norm(x, ord=None, axis=None, keepdims=False):
"""Matrix or vector norm.
This function is able to return one of eight different matrix norms, or one
of an infinite number of vector norms (described below), depending on the
value of the `ord` parameter.
Args:
x: Input tensor.
ord: Order of the norm (see table under Notes). The default is `None`.
axis: If `axis` is an integer, it specifies the axis of `x` along which
to compute the vector norms. If `axis` is a 2-tuple, it specifies
the axes that hold 2-D matrices, and the matrix norms of these
matrices are computed.
keepdims: If this is set to `True`, the axes which are reduced are left
in the result as dimensions with size one.
Note:
For values of `ord < 1`, the result is, strictly speaking, not a
mathematical 'norm', but it may still be useful for various numerical
purposes. The following norms can be calculated:
- For matrices:
- `ord=None`: Frobenius norm
- `ord="fro"`: Frobenius norm
- `ord=nuc`: nuclear norm
- `ord=np.inf`: `max(sum(abs(x), axis=1))`
- `ord=-np.inf`: `min(sum(abs(x), axis=1))`
- `ord=0`: not supported
- `ord=1`: `max(sum(abs(x), axis=0))`
- `ord=-1`: `min(sum(abs(x), axis=0))`
- `ord=2`: 2-norm (largest sing. value)
- `ord=-2`: smallest singular value
- other: not supported
- For vectors:
- `ord=None`: 2-norm
- `ord="fro"`: not supported
- `ord=nuc`: not supported
- `ord=np.inf`: `max(abs(x))`
- `ord=-np.inf`: `min(abs(x))`
- `ord=0`: `sum(x != 0)`
- `ord=1`: as below
- `ord=-1`: as below
- `ord=2`: as below
- `ord=-2`: as below
- other: `sum(abs(x)**ord)**(1./ord)`
Returns:
Norm of the matrix or vector(s).
Example:
>>> x = keras.ops.reshape(keras.ops.arange(9, dtype="float32") - 4, (3, 3))
>>> keras.ops.norm(x)
7.7459664
"""
if any_symbolic_tensors((x,)):
return Norm(ord=ord, axis=axis, keepdims=keepdims).symbolic_call(x)
x = backend.convert_to_tensor(x)
return backend.math.norm(x, ord=ord, axis=axis, keepdims=keepdims)
61 changes: 61 additions & 0 deletions keras/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from keras import testing
from keras.backend.common.keras_tensor import KerasTensor
from keras.ops import math as kmath
from keras.testing.test_utils import named_product


def _stft(
Expand Down Expand Up @@ -281,6 +282,16 @@ def test_rsqrt(self):
x = KerasTensor([None, 3])
self.assertEqual(kmath.rsqrt(x).shape, (None, 3))

def test_norm(self):
x = KerasTensor((None, 3))
self.assertEqual(kmath.norm(x).shape, ())

x = KerasTensor((None, 3, 3))
self.assertEqual(kmath.norm(x, axis=1).shape, (None, 3))
self.assertEqual(
kmath.norm(x, axis=1, keepdims=True).shape, (None, 1, 3)
)


class MathOpsStaticShapeTest(testing.TestCase):
@pytest.mark.skipif(
Expand Down Expand Up @@ -425,6 +436,10 @@ def test_solve(self):
outputs = kmath.solve(x1, x2)
self.assertEqual(outputs.shape, (2, 2))

def test_norm(self):
x = KerasTensor((2, 3))
self.assertEqual(kmath.norm(x).shape, ())


class MathOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.skipif(
Expand Down Expand Up @@ -879,6 +894,52 @@ def test_solve(self):
expected_result = np.array([[2, 0], [0, 2]], dtype="float32")
self.assertAllClose(output, expected_result)

@parameterized.named_parameters(
named_product(
ord=[None, "fro", "nuc", -np.inf, -2, -1, 0, 1, 2, np.inf, 123],
axis=[None, (0, 1), (0, 2)],
keepdims=[False, True],
)
)
def test_norm_matrices(self, ord, axis, keepdims):
if axis is None:
x = np.random.random((6, 7))
else:
x = np.random.random((5, 6, 7))
if ord in (0, 123):
error = RuntimeError if backend.backend() == "torch" else ValueError
with self.assertRaises(error):
kmath.norm(x, ord=ord, axis=axis, keepdims=keepdims)
return
output = kmath.norm(x, ord=ord, axis=axis, keepdims=keepdims)
expected_result = np.linalg.norm(
x, ord=ord, axis=axis, keepdims=keepdims
)
self.assertAllClose(output, expected_result)

@parameterized.named_parameters(
named_product(
ord=[None, "fro", "nuc", -np.inf, -2, -1, 0, 1, 2, np.inf, 123],
axis=[None, 1, -1],
keepdims=[False, True],
)
)
def test_norm_vectors(self, ord, axis, keepdims):
if axis is None:
x = np.random.random((5,))
else:
x = np.random.random((5, 6))
if ord in ("fro", "nuc"):
error = RuntimeError if backend.backend() == "torch" else ValueError
with self.assertRaises(error):
kmath.norm(x, ord=ord, axis=axis, keepdims=keepdims)
return
output = kmath.norm(x, ord=ord, axis=axis, keepdims=keepdims)
expected_result = np.linalg.norm(
x, ord=ord, axis=axis, keepdims=keepdims
)
self.assertAllClose(output, expected_result)


class QrOpTest(testing.TestCase):
def test_qr_init_mode_reduced(self):
Expand Down

0 comments on commit 06cde60

Please sign in to comment.