diff --git a/src/probnum/linops/_linear_operator.py b/src/probnum/linops/_linear_operator.py
index efe7e29ea..28010fbf0 100644
--- a/src/probnum/linops/_linear_operator.py
+++ b/src/probnum/linops/_linear_operator.py
@@ -80,6 +80,7 @@ def __init__(
         self._logabsdet_cache = None
         self._trace_cache = None

+        self._lu_cache = None
         self._cholesky_cache = None

         # Property inference
@@ -707,6 +708,37 @@ def _cholesky(self, lower: bool) -> LinearOperator:
             )
         )

+    def _lu_factor(self):
+        """Compute or retrieve a cached LU factorization of the operator.
+
+        This is a modified version of the SciPy implementation
+        (https://github.com/scipy/scipy/blob/v1.7.1/scipy/linalg/decomp_lu.py#L15-L84),
+        which only emits a warning, not an exception, if the matrix is singular.
+        """
+
+        if self._lu_cache is None:
+            from scipy.linalg.lapack import (  # pylint: disable=no-name-in-module,import-outside-toplevel
+                get_lapack_funcs,
+            )
+
+            a = np.asarray_chkfinite(self.todense())
+            (getrf,) = get_lapack_funcs(("getrf",), (a,))
+            lu, piv, info = getrf(a, overwrite_a=False)
+
+            if info < 0:
+                raise ValueError(
+                    f"illegal value in argument {-info} of internal getrf (lu_factor)"
+                )
+
+            if info > 0:
+                raise np.linalg.LinAlgError(
+                    f"Diagonal number {info} is exactly zero. Singular matrix."
+                )
+
+            self._lu_cache = lu, piv
+
+        return self._lu_cache
+
     ####################################################################################
     # Unary Arithmetic
     ####################################################################################
@@ -810,6 +842,9 @@ def inv(self) -> "LinearOperator":
             except NotImplementedError:
                 pass

+        # This does not need caching, since the `_InverseLinearOperator` only accesses
+        # quantities (particularly matrix decompositions), which are cached inside the
+        # original `LinearOperator`.
         return _InverseLinearOperator(self)

     def symmetrize(self) -> LinearOperator:
@@ -1048,17 +1083,10 @@ def broadcast_matmat(
         """Broadcasting for a (implicitly defined) matrix-matrix product.

         Convenience function / decorator to broadcast the definition of a matrix-matrix
-        product to vectors. This can be used to easily construct a new linear operator
-        only from a matrix-matrix product.
+        product to stacks of matrices. This can be used to easily construct a new linear
+        operator only from a matrix-matrix product.
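+
+        For instance (illustrative sketch; ``A`` below is an arbitrary stand-in
+        matrix, not part of this class), a plain 2-D matrix-matrix product is
+        lifted to stacks of matrices::
+
+            A = np.arange(6.0).reshape(3, 2)
+            matmul = LinearOperator.broadcast_matmat(lambda x: A @ x)
+            matmul(np.ones((5, 2, 4))).shape  # == (5, 3, 4)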
""" - - def _matmul(x: np.ndarray) -> np.ndarray: - if x.ndim == 2: - return matmat(x) - - return _apply_to_matrix_stack(matmat, x) - - return _matmul + return np.vectorize(matmat, signature="(n,k)->(m,k)") @property def _inexact_dtype(self) -> np.dtype: @@ -1068,27 +1096,6 @@ def _inexact_dtype(self) -> np.dtype: return np.double -def _apply_to_matrix_stack( - mat_fn: Callable[[np.ndarray], np.ndarray], x: np.ndarray -) -> np.ndarray: - idcs = np.ndindex(x.shape[:-2]) - - # Shape and dtype inference - idx0 = next(idcs) - y0 = mat_fn(x[idx0]) - - # Result buffer - y = np.empty(x.shape[:-2] + y0.shape, dtype=y0.dtype) - - # Fill buffer - y[idx0] = y0 - - for idx in idcs: - y[idx] = mat_fn(x[idx]) - - return y - - def _call_if_implemented(method: Optional[callable]) -> callable: if method is not None: return method @@ -1312,17 +1319,20 @@ def __init__(self, linop: LinearOperator): self._linop = linop - self.__factorization = None - self._cho_solve = False - - tmatmul = LinearOperator.broadcast_matmat(self._tmatmat) + solve = np.vectorize( + self._solve, + excluded=("trans",), + signature="(n, k)->(n, k)", + ) super().__init__( shape=self._linop.shape, dtype=self._linop._inexact_dtype, - matmul=LinearOperator.broadcast_matmat(self._matmat), - rmatmul=lambda x: tmatmul(x[..., np.newaxis])[..., 0], - transpose=lambda: TransposedLinearOperator(self, matmul=tmatmul), + matmul=solve, + transpose=lambda: TransposedLinearOperator( + self, + matmul=lambda x: solve(x, trans=True), + ), inverse=lambda: self._linop, det=lambda: 1 / self._linop.det(), logabsdet=lambda: -self._linop.logabsdet(), @@ -1335,65 +1345,45 @@ def __init__(self, linop: LinearOperator): def __repr__(self) -> str: return f"Inverse of {self._linop}" - @property - def factorization(self): - if self.__factorization is None: - try: - self.__factorization = ( - self._linop.cholesky(lower=True).T.todense(), - False, - ) - self._cho_solve = True - except np.linalg.LinAlgError: - self.__factorization = _InverseLinearOperator._lu_factor( - self._linop.todense(cache=False) - ) - - return self.__factorization - - def _matmat(self, x: np.ndarray) -> np.ndarray: - factorization = self.factorization # Precompute, so that _cho_solve will be set + def _solve(self, x: np.ndarray, trans: bool = False) -> np.ndarray: + """Solve :math:`A Y = X` for Y, where either :code:`A = self._linop` or + :code:`A = self._linop.T`, depending on the value of :code:`trans`. - if self._cho_solve: - return scipy.linalg.cho_solve(factorization, x, overwrite_b=False) - - return scipy.linalg.lu_solve(factorization, x, trans=0, overwrite_b=False) - - def _tmatmat(self, x: np.ndarray) -> np.ndarray: - factorization = self.factorization # Precompute, so that _cho_solve will be set - - if self._cho_solve: - return scipy.linalg.cho_solve(factorization, x.T, overwrite_b=False) - - return scipy.linalg.lu_solve(factorization, x, trans=1, overwrite_b=False) - - @staticmethod - def _lu_factor(a): - """This is a modified version of the original implementation in SciPy: + Parameters + ---------- + x + :code:`shape=(N,K)` -- + The right-hand sides :math:`X` of the linear systems, where + :code:`A.shape == (N, N)`. + trans + If :code:`False`, then :code:`A = self._linop`. + Otherwise :code:`A = self._linop.T`. - https://github.com/scipy/scipy/blob/v1.7.1/scipy/linalg/decomp_lu.py#L15-L84 - because for some reason, the SciPy implementation does not raise an exception - if the matrix is singular. + Returns + ------- + sol + The solutions :math:`A^{-1} X` of the linear systems. 
""" - from scipy.linalg.lapack import ( # pylint: disable=no-name-in-module,import-outside-toplevel - get_lapack_funcs, - ) - - a = np.asarray_chkfinite(a) - (getrf,) = get_lapack_funcs(("getrf",), (a,)) - lu, piv, info = getrf(a, overwrite_a=False) - - if info < 0: - raise ValueError( - f"illegal value in argument {-info} of internal getrf (lu_factor)" - ) + assert x.ndim == 2 - if info > 0: - raise np.linalg.LinAlgError( - f"Diagonal number {info} is exactly zero. Singular matrix." - ) - - return lu, piv + if self._linop.is_symmetric: + if self._linop.is_positive_definite is not False: + try: + # A @ x = A^T @ x, since A is symmetric + return scipy.linalg.cho_solve( + (self._linop.cholesky(lower=False).todense(), False), + x, + overwrite_b=False, + ) + except np.linalg.LinAlgError: + pass + + return scipy.linalg.lu_solve( + self._linop._lu_factor(), + x, + trans=1 if trans else 0, + overwrite_b=False, + ) class _TypeCastLinearOperator(LambdaLinearOperator): diff --git a/tests/test_linops/test_linops_cases/linear_operator_cases.py b/tests/test_linops/test_linops_cases/linear_operator_cases.py index a79c4ac99..ef5815d9c 100644 --- a/tests/test_linops/test_linops_cases/linear_operator_cases.py +++ b/tests/test_linops/test_linops_cases/linear_operator_cases.py @@ -91,3 +91,20 @@ def case_sparse_matrix_singular( ) return pn.linops.Matrix(matrix), matrix.toarray() + + +@pytest.mark.parametrize("rng", [np.random.default_rng(422)]) +def case_inverse( + rng: np.random.Generator, +) -> Tuple[pn.linops.LinearOperator, np.ndarray]: + N = 21 + + v = rng.uniform(0.2, 0.5, N) + + linop = pn.linops.LambdaLinearOperator( + shape=(N, N), + dtype=np.double, + matmul=lambda x: 2.0 * x + v[:, None] @ (v[None, :] @ x), + ) + + return linop.inv(), linop.inv().todense()