diff --git a/src/probnum/linops/_arithmetic_fallbacks.py b/src/probnum/linops/_arithmetic_fallbacks.py index 52f7c73ec..0b0577134 100644 --- a/src/probnum/linops/_arithmetic_fallbacks.py +++ b/src/probnum/linops/_arithmetic_fallbacks.py @@ -40,6 +40,7 @@ def __init__(self, linop: LinearOperator, scalar: ScalarLike): transpose=lambda: self._scalar * self._linop.T, inverse=self._inv, trace=lambda: self._scalar * self._linop.trace(), + diagonal=lambda: self._scalar * self._linop.diagonal(), ) # Matrix properties @@ -89,7 +90,6 @@ class SumLinearOperator(LambdaLinearOperator): """Sum of linear operators.""" def __init__(self, *summands: LinearOperator): - if not all(summand.shape == summands[0].shape for summand in summands): raise ValueError("All summands must have the same shape.") @@ -113,6 +113,9 @@ def __init__(self, *summands: LinearOperator): trace=lambda: functools.reduce( operator.add, (summand.trace() for summand in self._summands) ), + diagonal=lambda: functools.reduce( + operator.add, (summand.diagonal() for summand in self._summands) + ), ) # Matrix properties @@ -176,7 +179,6 @@ class ProductLinearOperator(LambdaLinearOperator): """(Operator) Product of linear operators.""" def __init__(self, *factors: LinearOperator): - if not all( lfactor.shape[1] == rfactor.shape[0] for lfactor, rfactor in zip(factors[:-1], factors[1:]) diff --git a/src/probnum/linops/_block.py b/src/probnum/linops/_block.py index d51f836b5..3addc77d8 100644 --- a/src/probnum/linops/_block.py +++ b/src/probnum/linops/_block.py @@ -132,6 +132,11 @@ def _rank(self) -> np.intp: return np.sum([block.rank() for block in self.blocks]) return super()._rank() + def _diagonal(self) -> np.ndarray: + if self._all_blocks_square: + return np.concatenate([block.diagonal() for block in self.blocks]) + return super()._diagonal() + def _cholesky(self, lower: bool) -> BlockDiagonalMatrix: if self._all_blocks_square: return BlockDiagonalMatrix( diff --git a/src/probnum/linops/_kronecker.py b/src/probnum/linops/_kronecker.py index 5bdecaf6e..f62598c13 100644 --- a/src/probnum/linops/_kronecker.py +++ b/src/probnum/linops/_kronecker.py @@ -172,6 +172,11 @@ def _trace(self) -> np.number: return super()._trace() + def _diagonal(self) -> np.ndarray: + if self.B.is_square: + return np.kron(self.A.diagonal(), self.B.diagonal()) + return super()._diagonal() + def _astype( self, dtype: DTypeLike, order: str, casting: str, copy: bool ) -> "Kronecker": @@ -200,7 +205,6 @@ def _matmul_kronecker(self, other: "Kronecker") -> "Kronecker": def _add_kronecker( self, other: "Kronecker" ) -> Union[NotImplementedType, "Kronecker"]: - if self.A is other.A or self.A == other.A: return Kronecker(A=self.A, B=self.B + other.B) @@ -212,7 +216,6 @@ def _add_kronecker( def _sub_kronecker( self, other: "Kronecker" ) -> Union[NotImplementedType, "Kronecker"]: - if self.A is other.A or self.A == other.A: return Kronecker(A=self.A, B=self.B - other.B) @@ -537,7 +540,6 @@ def _matmul_idkronecker(self, other: "IdentityKronecker") -> "IdentityKronecker" def _add_idkronecker( self, other: "IdentityKronecker" ) -> Union[NotImplementedType, "IdentityKronecker"]: - if self.A.shape == other.A.shape: return IdentityKronecker(num_blocks=self._num_blocks, B=self.B + other.B) @@ -546,7 +548,6 @@ def _add_idkronecker( def _sub_idkronecker( self, other: "IdentityKronecker" ) -> Union[NotImplementedType, "IdentityKronecker"]: - if self.A.shape == other.A.shape: return IdentityKronecker(num_blocks=self._num_blocks, B=self.B - other.B) diff --git a/src/probnum/linops/_linear_operator.py b/src/probnum/linops/_linear_operator.py index 5a4b3ca81..81854bd4e 100644 --- a/src/probnum/linops/_linear_operator.py +++ b/src/probnum/linops/_linear_operator.py @@ -109,6 +109,7 @@ def __init__( self._det_cache = None self._logabsdet_cache = None self._trace_cache = None + self._diagonal_cache = None self._lu_cache = None self._cholesky_cache = None @@ -737,19 +738,7 @@ def _trace(self) -> np.number: trace : float Trace of the linear operator. """ - - vec = np.zeros(self.shape[1], dtype=self.dtype) - - vec[0] = 1 - trace = (self @ vec)[0] - vec[0] = 0 - - for i in range(1, self.shape[0]): - vec[i] = 1 - trace += (self @ vec)[i] - vec[i] = 0 - - return trace + return np.sum(self.diagonal()) def trace(self) -> np.number: r"""Trace of the linear operator. @@ -777,6 +766,31 @@ def trace(self) -> np.number: return self._trace_cache + def _diagonal(self) -> np.ndarray: + """Diagonal of the linear operator. + + You may implement this method in a subclass. + """ + D = np.min(self.shape) + diag = np.zeros(D, dtype=self.dtype) + vec = np.zeros(self.shape[1], dtype=self.dtype) + + for i in range(D): + vec[i] = 1 + diag[i] = (self @ vec)[i] + vec[i] = 0 + + return diag + + def diagonal(self) -> np.ndarray: + """Diagonal of the linear operator.""" + if self._diagonal_cache is None: + self._diagonal_cache = self._diagonal() + + self._diagonal_cache.setflags(write=False) + + return self._diagonal_cache + #################################################################################### # Matrix Decompositions #################################################################################### @@ -1337,6 +1351,7 @@ def __init__( det: Optional[Callable[[], np.inexact]] = None, logabsdet: Optional[Callable[[], np.floating]] = None, trace: Optional[Callable[[], np.number]] = None, + diagonal: Optional[Callable[[], np.ndarray]] = None, ): super().__init__(shape, dtype) @@ -1357,6 +1372,7 @@ def __init__( self._det_fn = det self._logabsdet_fn = logabsdet self._trace_fn = trace + self._diagonal_fn = diagonal def _matmul(self, x: np.ndarray) -> np.ndarray: return self._matmul_fn(x) @@ -1429,6 +1445,12 @@ def _trace(self) -> np.number: return self._trace_fn() + def _diagonal(self) -> np.ndarray: + if self._diagonal_fn is None: + return super()._diagonal() + + return self._diagonal_fn() + class TransposedLinearOperator(LambdaLinearOperator): """Transposition of a linear operator.""" @@ -1457,6 +1479,7 @@ def __init__( det=self._linop.det, logabsdet=self._linop.logabsdet, trace=self._linop.trace, + diagonal=self._linop.diagonal, ) def _astype( @@ -1561,6 +1584,7 @@ def __init__( det=lambda: self._linop.det().astype(self._inexact_dtype), logabsdet=lambda: self._linop.logabsdet().astype(self._inexact_dtype), trace=lambda: self._linop.trace().astype(dtype), + diagonal=lambda: self._linop.diagonal().astype(dtype), ) def _astype( @@ -1591,6 +1615,7 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]): matmul = LinearOperator.broadcast_matmat(lambda x: self.A @ x) todense = self.A.toarray trace = lambda: self.A.diagonal().sum() + diagonal = self.A.diagonal else: self.A = np.asarray(A) self.A.setflags(write=False) @@ -1598,6 +1623,7 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]): matmul = lambda x: self.A @ x todense = lambda: self.A trace = lambda: np.trace(self.A) + diagonal = lambda: np.diagonal(self.A) super().__init__( self.A.shape, @@ -1605,6 +1631,7 @@ def __init__(self, A: Union[ArrayLike, scipy.sparse.spmatrix]): matmul=matmul, todense=todense, trace=trace, + diagonal=diagonal, ) def _transpose(self) -> "Matrix": @@ -1691,6 +1718,7 @@ def __init__( trace=lambda: probnum.utils.as_numpy_scalar( self.shape[0], dtype=self.dtype ), + diagonal=lambda: np.ones(shape[0], dtype=self.dtype), ) # Matrix properties diff --git a/src/probnum/linops/_scaling.py b/src/probnum/linops/_scaling.py index 52ddfcd6f..038b5665f 100644 --- a/src/probnum/linops/_scaling.py +++ b/src/probnum/linops/_scaling.py @@ -312,6 +312,9 @@ def _cond_isotropic(self, p: Union[None, int, float, str]) -> np.floating: return np.linalg.cond(self.todense(cache=False), p=p) + def _diagonal(self) -> np.ndarray: + return self.factors + def _cholesky(self, lower: bool = True) -> Scaling: if self._scalar is not None: if self._scalar <= 0: @@ -347,6 +350,7 @@ def __init__(self, shape, dtype=np.float64): det = lambda: np.zeros(shape=(), dtype=dtype) trace = lambda: np.zeros(shape=(), dtype=dtype) + diagonal = lambda: np.zeros(shape=(np.min(shape),), dtype=dtype) def matmul(x: np.ndarray) -> np.ndarray: target_shape = list(x.shape) @@ -363,6 +367,7 @@ def matmul(x: np.ndarray) -> np.ndarray: eigvals=eigvals, det=det, trace=trace, + diagonal=diagonal, ) # Matrix properties diff --git a/tests/test_linops/test_linops.py b/tests/test_linops/test_linops.py index 530c8b0f3..0e0ab6cf5 100644 --- a/tests/test_linops/test_linops.py +++ b/tests/test_linops/test_linops.py @@ -381,6 +381,18 @@ def test_trace(linop: pn.linops.LinearOperator, matrix: np.ndarray): linop.trace() +@pytest_cases.parametrize_with_cases("linop,matrix", cases=case_modules) +def test_diagonal(linop: pn.linops.LinearOperator, matrix: np.ndarray): + linop_diagonal = linop.diagonal() + matrix_diagonal = np.diagonal(matrix) + + assert isinstance(linop_diagonal, np.ndarray) + assert linop_diagonal.shape == matrix_diagonal.shape + assert linop_diagonal.dtype == matrix_diagonal.dtype + + np.testing.assert_allclose(linop_diagonal, matrix_diagonal) + + @pytest_cases.parametrize_with_cases("linop,matrix", cases=case_modules) def test_transpose(linop: pn.linops.LinearOperator, matrix: np.ndarray): matrix_transpose = matrix.transpose()