Skip to content

Commit

Permalink
Caching of the LU decomposition computed by _InverseLinearOperator (#…
Browse files Browse the repository at this point in the history
…737)

* Cache LU decomposition in `LinearOperator`

* Add test case for `_InverseLinearOperator`

* Incorporate review comments

* `broadcast_matmat` -> `np.vectorize`
  • Loading branch information
marvinpfoertner authored Nov 9, 2022
1 parent ce4bbba commit 6fadc00
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 93 deletions.
176 changes: 83 additions & 93 deletions src/probnum/linops/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
self._logabsdet_cache = None
self._trace_cache = None

self._lu_cache = None
self._cholesky_cache = None

# Property inference
Expand Down Expand Up @@ -707,6 +708,37 @@ def _cholesky(self, lower: bool) -> LinearOperator:
)
)

def _lu_factor(self):
"""This is a modified version of the original implementation in SciPy:
https://github.com/scipy/scipy/blob/v1.7.1/scipy/linalg/decomp_lu.py#L15-L84
because the SciPy implementation does not raise an exception if the matrix is
singular.
"""

if self._lu_cache is None:
from scipy.linalg.lapack import ( # pylint: disable=no-name-in-module,import-outside-toplevel
get_lapack_funcs,
)

a = np.asarray_chkfinite(self.todense())
(getrf,) = get_lapack_funcs(("getrf",), (a,))
lu, piv, info = getrf(a, overwrite_a=False)

if info < 0:
raise ValueError(
f"illegal value in argument {-info} of internal getrf (lu_factor)"
)

if info > 0:
raise np.linalg.LinAlgError(
f"Diagonal number {info} is exactly zero. Singular matrix."
)

self._lu_cache = lu, piv

return self._lu_cache

####################################################################################
# Unary Arithmetic
####################################################################################
Expand Down Expand Up @@ -810,6 +842,9 @@ def inv(self) -> "LinearOperator":
except NotImplementedError:
pass

# This does not need caching, since the `_InverseLinearOperator` only accesses
# quantities (particularly matrix decompositions), which are cached inside the
# original `LinearOperator`.
return _InverseLinearOperator(self)

def symmetrize(self) -> LinearOperator:
Expand Down Expand Up @@ -1048,17 +1083,10 @@ def broadcast_matmat(
"""Broadcasting for a (implicitly defined) matrix-matrix product.
Convenience function / decorator to broadcast the definition of a matrix-matrix
product to vectors. This can be used to easily construct a new linear operator
only from a matrix-matrix product.
product to stacks of matrices. This can be used to easily construct a new linear
operator only from a matrix-matrix product.
"""

def _matmul(x: np.ndarray) -> np.ndarray:
if x.ndim == 2:
return matmat(x)

return _apply_to_matrix_stack(matmat, x)

return _matmul
return np.vectorize(matmat, signature="(n,k)->(m,k)")

@property
def _inexact_dtype(self) -> np.dtype:
Expand All @@ -1068,27 +1096,6 @@ def _inexact_dtype(self) -> np.dtype:
return np.double


def _apply_to_matrix_stack(
mat_fn: Callable[[np.ndarray], np.ndarray], x: np.ndarray
) -> np.ndarray:
idcs = np.ndindex(x.shape[:-2])

# Shape and dtype inference
idx0 = next(idcs)
y0 = mat_fn(x[idx0])

# Result buffer
y = np.empty(x.shape[:-2] + y0.shape, dtype=y0.dtype)

# Fill buffer
y[idx0] = y0

for idx in idcs:
y[idx] = mat_fn(x[idx])

return y


def _call_if_implemented(method: Optional[callable]) -> callable:
if method is not None:
return method
Expand Down Expand Up @@ -1312,17 +1319,20 @@ def __init__(self, linop: LinearOperator):

self._linop = linop

self.__factorization = None
self._cho_solve = False

tmatmul = LinearOperator.broadcast_matmat(self._tmatmat)
solve = np.vectorize(
self._solve,
excluded=("trans",),
signature="(n, k)->(n, k)",
)

super().__init__(
shape=self._linop.shape,
dtype=self._linop._inexact_dtype,
matmul=LinearOperator.broadcast_matmat(self._matmat),
rmatmul=lambda x: tmatmul(x[..., np.newaxis])[..., 0],
transpose=lambda: TransposedLinearOperator(self, matmul=tmatmul),
matmul=solve,
transpose=lambda: TransposedLinearOperator(
self,
matmul=lambda x: solve(x, trans=True),
),
inverse=lambda: self._linop,
det=lambda: 1 / self._linop.det(),
logabsdet=lambda: -self._linop.logabsdet(),
Expand All @@ -1335,65 +1345,45 @@ def __init__(self, linop: LinearOperator):
def __repr__(self) -> str:
return f"Inverse of {self._linop}"

@property
def factorization(self):
if self.__factorization is None:
try:
self.__factorization = (
self._linop.cholesky(lower=True).T.todense(),
False,
)
self._cho_solve = True
except np.linalg.LinAlgError:
self.__factorization = _InverseLinearOperator._lu_factor(
self._linop.todense(cache=False)
)

return self.__factorization

def _matmat(self, x: np.ndarray) -> np.ndarray:
factorization = self.factorization # Precompute, so that _cho_solve will be set
def _solve(self, x: np.ndarray, trans: bool = False) -> np.ndarray:
"""Solve :math:`A Y = X` for Y, where either :code:`A = self._linop` or
:code:`A = self._linop.T`, depending on the value of :code:`trans`.
if self._cho_solve:
return scipy.linalg.cho_solve(factorization, x, overwrite_b=False)

return scipy.linalg.lu_solve(factorization, x, trans=0, overwrite_b=False)

def _tmatmat(self, x: np.ndarray) -> np.ndarray:
factorization = self.factorization # Precompute, so that _cho_solve will be set

if self._cho_solve:
return scipy.linalg.cho_solve(factorization, x.T, overwrite_b=False)

return scipy.linalg.lu_solve(factorization, x, trans=1, overwrite_b=False)

@staticmethod
def _lu_factor(a):
"""This is a modified version of the original implementation in SciPy:
Parameters
----------
x
:code:`shape=(N,K)` --
The right-hand sides :math:`X` of the linear systems, where
:code:`A.shape == (N, N)`.
trans
If :code:`False`, then :code:`A = self._linop`.
Otherwise :code:`A = self._linop.T`.
https://github.com/scipy/scipy/blob/v1.7.1/scipy/linalg/decomp_lu.py#L15-L84
because for some reason, the SciPy implementation does not raise an exception
if the matrix is singular.
Returns
-------
sol
The solutions :math:`A^{-1} X` of the linear systems.
"""
from scipy.linalg.lapack import ( # pylint: disable=no-name-in-module,import-outside-toplevel
get_lapack_funcs,
)

a = np.asarray_chkfinite(a)
(getrf,) = get_lapack_funcs(("getrf",), (a,))
lu, piv, info = getrf(a, overwrite_a=False)

if info < 0:
raise ValueError(
f"illegal value in argument {-info} of internal getrf (lu_factor)"
)
assert x.ndim == 2

if info > 0:
raise np.linalg.LinAlgError(
f"Diagonal number {info} is exactly zero. Singular matrix."
)

return lu, piv
if self._linop.is_symmetric:
if self._linop.is_positive_definite is not False:
try:
# A @ x = A^T @ x, since A is symmetric
return scipy.linalg.cho_solve(
(self._linop.cholesky(lower=False).todense(), False),
x,
overwrite_b=False,
)
except np.linalg.LinAlgError:
pass

return scipy.linalg.lu_solve(
self._linop._lu_factor(),
x,
trans=1 if trans else 0,
overwrite_b=False,
)


class _TypeCastLinearOperator(LambdaLinearOperator):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_linops/test_linops_cases/linear_operator_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,20 @@ def case_sparse_matrix_singular(
)

return pn.linops.Matrix(matrix), matrix.toarray()


@pytest.mark.parametrize("rng", [np.random.default_rng(422)])
def case_inverse(
rng: np.random.Generator,
) -> Tuple[pn.linops.LinearOperator, np.ndarray]:
N = 21

v = rng.uniform(0.2, 0.5, N)

linop = pn.linops.LambdaLinearOperator(
shape=(N, N),
dtype=np.double,
matmul=lambda x: 2.0 * x + v[:, None] @ (v[None, :] @ x),
)

return linop.inv(), linop.inv().todense()

0 comments on commit 6fadc00

Please sign in to comment.