Skip to content

Commit

Permalink
Linop diagonal (#820)
Browse files Browse the repository at this point in the history
* Add `LinearOperator.diagonal()`

... which computes the diagonal of the linear operator. Default
implementation multiplies with unit vectors, and subclasses may 
use more efficient implementations, e.g. for Kronecker linear operators.

* Add tests for `LinearOperator.diagonal()`
  • Loading branch information
timweiland authored May 10, 2023
1 parent 555485f commit 65f5604
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 19 deletions.
6 changes: 4 additions & 2 deletions src/probnum/linops/_arithmetic_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, linop: LinearOperator, scalar: ScalarLike):
transpose=lambda: self._scalar * self._linop.T,
inverse=self._inv,
trace=lambda: self._scalar * self._linop.trace(),
diagonal=lambda: self._scalar * self._linop.diagonal(),
)

# Matrix properties
Expand Down Expand Up @@ -89,7 +90,6 @@ class SumLinearOperator(LambdaLinearOperator):
"""Sum of linear operators."""

def __init__(self, *summands: LinearOperator):

if not all(summand.shape == summands[0].shape for summand in summands):
raise ValueError("All summands must have the same shape.")

Expand All @@ -113,6 +113,9 @@ def __init__(self, *summands: LinearOperator):
trace=lambda: functools.reduce(
operator.add, (summand.trace() for summand in self._summands)
),
diagonal=lambda: functools.reduce(
operator.add, (summand.diagonal() for summand in self._summands)
),
)

# Matrix properties
Expand Down Expand Up @@ -176,7 +179,6 @@ class ProductLinearOperator(LambdaLinearOperator):
"""(Operator) Product of linear operators."""

def __init__(self, *factors: LinearOperator):

if not all(
lfactor.shape[1] == rfactor.shape[0]
for lfactor, rfactor in zip(factors[:-1], factors[1:])
Expand Down
5 changes: 5 additions & 0 deletions src/probnum/linops/_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def _rank(self) -> np.intp:
return np.sum([block.rank() for block in self.blocks])
return super()._rank()

def _diagonal(self) -> np.ndarray:
if self._all_blocks_square:
return np.concatenate([block.diagonal() for block in self.blocks])
return super()._diagonal()

def _cholesky(self, lower: bool) -> BlockDiagonalMatrix:
if self._all_blocks_square:
return BlockDiagonalMatrix(
Expand Down
9 changes: 5 additions & 4 deletions src/probnum/linops/_kronecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ def _trace(self) -> np.number:

return super()._trace()

def _diagonal(self) -> np.ndarray:
if self.B.is_square:
return np.kron(self.A.diagonal(), self.B.diagonal())
return super()._diagonal()

def _astype(
self, dtype: DTypeLike, order: str, casting: str, copy: bool
) -> "Kronecker":
Expand Down Expand Up @@ -200,7 +205,6 @@ def _matmul_kronecker(self, other: "Kronecker") -> "Kronecker":
def _add_kronecker(
self, other: "Kronecker"
) -> Union[NotImplementedType, "Kronecker"]:

if self.A is other.A or self.A == other.A:
return Kronecker(A=self.A, B=self.B + other.B)

Expand All @@ -212,7 +216,6 @@ def _add_kronecker(
def _sub_kronecker(
self, other: "Kronecker"
) -> Union[NotImplementedType, "Kronecker"]:

if self.A is other.A or self.A == other.A:
return Kronecker(A=self.A, B=self.B - other.B)

Expand Down Expand Up @@ -537,7 +540,6 @@ def _matmul_idkronecker(self, other: "IdentityKronecker") -> "IdentityKronecker"
def _add_idkronecker(
self, other: "IdentityKronecker"
) -> Union[NotImplementedType, "IdentityKronecker"]:

if self.A.shape == other.A.shape:
return IdentityKronecker(num_blocks=self._num_blocks, B=self.B + other.B)

Expand All @@ -546,7 +548,6 @@ def _add_idkronecker(
def _sub_idkronecker(
self, other: "IdentityKronecker"
) -> Union[NotImplementedType, "IdentityKronecker"]:

if self.A.shape == other.A.shape:
return IdentityKronecker(num_blocks=self._num_blocks, B=self.B - other.B)

Expand Down
54 changes: 41 additions & 13 deletions src/probnum/linops/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
self._det_cache = None
self._logabsdet_cache = None
self._trace_cache = None
self._diagonal_cache = None

self._lu_cache = None
self._cholesky_cache = None
Expand Down Expand Up @@ -737,19 +738,7 @@ def _trace(self) -> np.number:
trace : float
Trace of the linear operator.
"""

vec = np.zeros(self.shape[1], dtype=self.dtype)

vec[0] = 1
trace = (self @ vec)[0]
vec[0] = 0

for i in range(1, self.shape[0]):
vec[i] = 1
trace += (self @ vec)[i]
vec[i] = 0

return trace
return np.sum(self.diagonal())

def trace(self) -> np.number:
r"""Trace of the linear operator.
Expand Down Expand Up @@ -777,6 +766,31 @@ def trace(self) -> np.number:

return self._trace_cache

def _diagonal(self) -> np.ndarray:
"""Diagonal of the linear operator.
You may implement this method in a subclass.
"""
D = np.min(self.shape)
diag = np.zeros(D, dtype=self.dtype)
vec = np.zeros(self.shape[1], dtype=self.dtype)

for i in range(D):
vec[i] = 1
diag[i] = (self @ vec)[i]
vec[i] = 0

return diag

def diagonal(self) -> np.ndarray:
"""Diagonal of the linear operator."""
if self._diagonal_cache is None:
self._diagonal_cache = self._diagonal()

self._diagonal_cache.setflags(write=False)

return self._diagonal_cache

####################################################################################
# Matrix Decompositions
####################################################################################
Expand Down Expand Up @@ -1337,6 +1351,7 @@ def __init__(
det: Optional[Callable[[], np.inexact]] = None,
logabsdet: Optional[Callable[[], np.floating]] = None,
trace: Optional[Callable[[], np.number]] = None,
diagonal: Optional[Callable[[], np.ndarray]] = None,
):
super().__init__(shape, dtype)

Expand All @@ -1357,6 +1372,7 @@ def __init__(
self._det_fn = det
self._logabsdet_fn = logabsdet
self._trace_fn = trace
self._diagonal_fn = diagonal

def _matmul(self, x: np.ndarray) -> np.ndarray:
return self._matmul_fn(x)
Expand Down Expand Up @@ -1429,6 +1445,12 @@ def _trace(self) -> np.number:

return self._trace_fn()

def _diagonal(self) -> np.ndarray:
if self._diagonal_fn is None:
return super()._diagonal()

return self._diagonal_fn()


class TransposedLinearOperator(LambdaLinearOperator):
"""Transposition of a linear operator."""
Expand Down Expand Up @@ -1457,6 +1479,7 @@ def __init__(
det=self._linop.det,
logabsdet=self._linop.logabsdet,
trace=self._linop.trace,
diagonal=self._linop.diagonal,
)

def _astype(
Expand Down Expand Up @@ -1561,6 +1584,7 @@ def __init__(
det=lambda: self._linop.det().astype(self._inexact_dtype),
logabsdet=lambda: self._linop.logabsdet().astype(self._inexact_dtype),
trace=lambda: self._linop.trace().astype(dtype),
diagonal=lambda: self._linop.diagonal().astype(dtype),
)

def _astype(
Expand Down Expand Up @@ -1591,20 +1615,23 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]):
matmul = LinearOperator.broadcast_matmat(lambda x: self.A @ x)
todense = self.A.toarray
trace = lambda: self.A.diagonal().sum()
diagonal = self.A.diagonal
else:
self.A = np.asarray(A)
self.A.setflags(write=False)

matmul = lambda x: self.A @ x
todense = lambda: self.A
trace = lambda: np.trace(self.A)
diagonal = lambda: np.diagonal(self.A)

super().__init__(
self.A.shape,
self.A.dtype,
matmul=matmul,
todense=todense,
trace=trace,
diagonal=diagonal,
)

def _transpose(self) -> "Matrix":
Expand Down Expand Up @@ -1691,6 +1718,7 @@ def __init__(
trace=lambda: probnum.utils.as_numpy_scalar(
self.shape[0], dtype=self.dtype
),
diagonal=lambda: np.ones(shape[0], dtype=self.dtype),
)

# Matrix properties
Expand Down
5 changes: 5 additions & 0 deletions src/probnum/linops/_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def _cond_isotropic(self, p: Union[None, int, float, str]) -> np.floating:

return np.linalg.cond(self.todense(cache=False), p=p)

def _diagonal(self) -> np.ndarray:
return self.factors

def _cholesky(self, lower: bool = True) -> Scaling:
if self._scalar is not None:
if self._scalar <= 0:
Expand Down Expand Up @@ -347,6 +350,7 @@ def __init__(self, shape, dtype=np.float64):
det = lambda: np.zeros(shape=(), dtype=dtype)

trace = lambda: np.zeros(shape=(), dtype=dtype)
diagonal = lambda: np.zeros(shape=(np.min(shape),), dtype=dtype)

def matmul(x: np.ndarray) -> np.ndarray:
target_shape = list(x.shape)
Expand All @@ -363,6 +367,7 @@ def matmul(x: np.ndarray) -> np.ndarray:
eigvals=eigvals,
det=det,
trace=trace,
diagonal=diagonal,
)

# Matrix properties
Expand Down
12 changes: 12 additions & 0 deletions tests/test_linops/test_linops.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@ def test_trace(linop: pn.linops.LinearOperator, matrix: np.ndarray):
linop.trace()


@pytest_cases.parametrize_with_cases("linop,matrix", cases=case_modules)
def test_diagonal(linop: pn.linops.LinearOperator, matrix: np.ndarray):
linop_diagonal = linop.diagonal()
matrix_diagonal = np.diagonal(matrix)

assert isinstance(linop_diagonal, np.ndarray)
assert linop_diagonal.shape == matrix_diagonal.shape
assert linop_diagonal.dtype == matrix_diagonal.dtype

np.testing.assert_allclose(linop_diagonal, matrix_diagonal)


@pytest_cases.parametrize_with_cases("linop,matrix", cases=case_modules)
def test_transpose(linop: pn.linops.LinearOperator, matrix: np.ndarray):
matrix_transpose = matrix.transpose()
Expand Down

0 comments on commit 65f5604

Please sign in to comment.