Skip to content

Commit

Permalink
Matrix: Add inversion methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaudin47 committed Apr 26, 2024
1 parent 567b8e4 commit ee5b184
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 22 deletions.
24 changes: 24 additions & 0 deletions lib/src/Base/Type/MatrixImplementation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,30 @@ MatrixImplementation MatrixImplementation::solveLinearSystemSquare(const MatrixI
return A.solveLinearSystemSquareInPlace(b);
}

/* Square matrix inverse */
MatrixImplementation MatrixImplementation::inverseSquare() const
{
if (nbColumns_ != nbRows_ ) throw InvalidDimensionException(HERE) << "The matrix has " << nbRows_ << " and " << nbColumns_ << ", expected a square matrix.";
MatrixImplementation identity(nbRows_, nbColumns_);
for(UnsignedInteger i = 0; i < nbRows_; ++i)
identity(i, i) = 1.0;
MatrixImplementation A(*this);
const MatrixImplementation inverseMatrix(A.solveLinearSystemSquareInPlace(identity));
return inverseMatrix;
}

/* Symmetric matrix inverse */
MatrixImplementation MatrixImplementation::inverseSym() const
{
if (nbColumns_ != nbRows_ ) throw InvalidDimensionException(HERE) << "The matrix has " << nbRows_ << " and " << nbColumns_ << ", expected a square matrix.";
MatrixImplementation identity(nbRows_, nbColumns_);
for(UnsignedInteger i = 0; i < nbRows_; ++i)
identity(i, i) = 1.0;
MatrixImplementation A(*this);
const MatrixImplementation inverseMatrix(A.solveLinearSystemSymInPlace(identity));
return inverseMatrix;
}

/* Resolution of a linear system : square matrix */
Point MatrixImplementation::solveLinearSystemSquareInPlace(const Point & b)
{
Expand Down
7 changes: 7 additions & 0 deletions lib/src/Base/Type/SquareMatrix.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,11 @@ Bool SquareMatrix::isDiagonal() const
return true;
}


/** Compute inverse */
SquareMatrix SquareMatrix::inverse() const
{
return getImplementation()->inverseSquare();
}

END_NAMESPACE_OPENTURNS
5 changes: 5 additions & 0 deletions lib/src/Base/Type/SymmetricMatrix.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -348,5 +348,10 @@ Scalar SymmetricMatrix::computeSumElements() const
return getImplementation()->computeSumElements();
}

/** Compute inverse */
SymmetricMatrix SymmetricMatrix::inverse() const
{
return getImplementation()->inverseSym();
}

END_NAMESPACE_OPENTURNS
6 changes: 6 additions & 0 deletions lib/src/Base/Type/openturns/MatrixImplementation.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@ public:
MatrixImplementation solveLinearSystemSquareInPlace(const MatrixImplementation & b);
MatrixImplementation solveLinearSystemSquare(const MatrixImplementation & b) const;

/** Square inverse */
MatrixImplementation inverseSquare() const;

/** Symmetric inverse */
MatrixImplementation inverseSym() const;

/** Resolution of a linear system in case of a triangular matrix */
Point solveLinearSystemTri(const Point & b,
const Bool lower = true,
Expand Down
7 changes: 5 additions & 2 deletions lib/src/Base/Type/openturns/SquareMatrix.hxx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// -*- C++ -*-
/**
* @brief SquareMatrix implements the classical mathematical square matrix
* @brief SquareMatrix implements the square matrix
*
* Copyright 2005-2024 Airbus-EDF-IMACS-ONERA-Phimeca
*
Expand Down Expand Up @@ -35,7 +35,7 @@ class SquareComplexMatrix;
/**
* @class SquareMatrix
*
* SquareMatrix implements the classical mathematical square matrix
* SquareMatrix implements the square matrix
*/

class OT_API SquareMatrix :
Expand Down Expand Up @@ -150,6 +150,9 @@ public:
/** Check if it is diagonal */
Bool isDiagonal() const;

/** Inverse matrix*/
SquareMatrix inverse() const;

protected:

private:
Expand Down
7 changes: 5 additions & 2 deletions lib/src/Base/Type/openturns/SymmetricMatrix.hxx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// -*- C++ -*-
/**
* @brief SymmetricMatrix implements the classical mathematical symmetric matrix
* @brief SymmetricMatrix implements the symmetric matrix
*
* Copyright 2005-2024 Airbus-EDF-IMACS-ONERA-Phimeca
*
Expand Down Expand Up @@ -31,7 +31,7 @@ class IdentityMatrix;
/**
* @class SymmetricMatrix
*
* SymmetricMatrix implements the classical mathematical square matrix
* SymmetricMatrix implements the symmetric matrix
*/

class OT_API SymmetricMatrix :
Expand Down Expand Up @@ -169,6 +169,9 @@ public:
/** Sum all coefficients */
Scalar computeSumElements() const override;

/** Inverse matrix*/
SymmetricMatrix inverse() const;

protected:


Expand Down
23 changes: 23 additions & 0 deletions lib/test/t_SquareMatrix_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,29 @@ int main(int, char *[])
<< "squareMatrix1 is empty = " << squareMatrix1.isEmpty() << std::endl
<< "squareMatrix5 is empty = " << squareMatrix5.isEmpty() << std::endl;

/* Check inverse() */
SquareMatrix squareMatrix6(3);
squareMatrix6(0, 0) = 1.0;
squareMatrix6(0, 1) = 2.0;
squareMatrix6(0, 2) = 3.0;
squareMatrix6(1, 0) = 3.0;
squareMatrix6(1, 1) = 2.0;
squareMatrix6(1, 2) = 1.0;
squareMatrix6(2, 0) = 2.0;
squareMatrix6(2, 1) = 1.0;
squareMatrix6(2, 2) = 3.0;
SquareMatrix squareMatrix7(squareMatrix6.inverse());
SquareMatrix inverseReference(3);
inverseReference(0, 0) = -5.0 / 12.0;
inverseReference(0, 1) = 3.0 / 12.0;
inverseReference(0, 2) = 4.0 / 12.0;
inverseReference(1, 0) = 7.0 / 12.0;
inverseReference(1, 1) = 3.0 / 12.0;
inverseReference(1, 2) = -8.0 / 12.0;
inverseReference(2, 0) = 1.0 / 12.0;
inverseReference(2, 1) = -3.0 / 12.0;
inverseReference(2, 2) = 4.0 / 12.0;
assert_almost_equal(squareMatrix7, inverseReference);
}
catch (TestFailed & ex)
{
Expand Down
24 changes: 24 additions & 0 deletions lib/test/t_SymmetricMatrix_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,30 @@ int main(int, char *[])
M3(0, 2) = 7.1;
M3(1, 2) = 9.0;
fullprint << "SM * M3 = " << SM * M3 << std::endl;

/* Check inverse() */
SymmetricMatrix symMatrix6(3);
symMatrix6(0, 0) = 4.0;
symMatrix6(0, 1) = 2.0;
symMatrix6(0, 2) = 1.0;
symMatrix6(1, 0) = 2.0;
symMatrix6(1, 1) = 5.0;
symMatrix6(1, 2) = 3.0;
symMatrix6(2, 0) = 1.0;
symMatrix6(2, 1) = 3.0;
symMatrix6(2, 2) = 6.0;
SymmetricMatrix symMatrix7(symMatrix6.inverse());
SymmetricMatrix inverseReference(3);
inverseReference(0, 0) = 21.0 / 67.0;
inverseReference(0, 1) = -9.0 / 67.0;
inverseReference(0, 2) = 1.0 / 67.0;
inverseReference(1, 0) = -9.0 / 67.0;
inverseReference(1, 1) = 23.0 / 67.0;
inverseReference(1, 2) = -10.0 / 67.0;
inverseReference(2, 0) = 1.0 / 67.0;
inverseReference(2, 1) = -10.0 / 67.0;
inverseReference(2, 2) = 16.0 / 67.0;
assert_almost_equal(symMatrix7, inverseReference);
}
catch (TestFailed & ex)
{
Expand Down
35 changes: 19 additions & 16 deletions python/src/Matrix_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ The economic QR factorization of a rectangular matrix :math:`\mat{M}` with
.. math::

\mat{M} = \mat{Q} \mat{R}
= \mat{Q} \begin{bmatrix} \mat{R_1} \\ \mat{0} \end{bmatrix}
= \begin{bmatrix} \mat{Q_1}, \mat{Q_2} \end{bmatrix}
\begin{bmatrix} \mat{R_1} \\ \mat{0} \end{bmatrix}
= \mat{Q_1} \mat{R_1}

where :math:`\mat{R_1}` is an :math:`n_c \times n_c` upper triangular matrix,
:math:`\mat{Q_1}` is :math:`n_r \times n_c`, :math:`\mat{Q_2}` is
:math:`n_r \times (n_r - n_c)`, and :math:`\mat{Q_1}` and :math:`\mat{Q_2}`
= \mat{Q} \begin{bmatrix} \mat{R}_1 \\ \mat{0} \end{bmatrix}
= \begin{bmatrix} \mat{Q}_1, \mat{Q}_2 \end{bmatrix}
\begin{bmatrix} \mat{R}_1 \\ \mat{0} \end{bmatrix}
= \mat{Q}_1 \mat{R}_1

where :math:`\mat{R}_1` is an :math:`n_c \times n_c` upper triangular matrix,
:math:`\mat{Q}_1` is :math:`n_r \times n_c`, :math:`\mat{Q}_2` is
:math:`n_r \times (n_r - n_c)`, and :math:`\mat{Q}_1` and :math:`\mat{Q}_2`
both have orthogonal columns.

Parameters
Expand Down Expand Up @@ -118,7 +118,7 @@ least-square sense because it implies solving a (simple) triangular system:
.. math::

\vect{\hat{x}} = \arg\min\limits_{\vect{x} \in \Rset^{n_r}} \|\mat{M} \vect{x} - \vect{b}\|
= \mat{R_1}^{-1} (\Tr{\mat{Q_1}} \vect{b})
= \mat{R}_1^{-1} (\Tr{\mat{Q}_1} \vect{b})

This uses LAPACK's `DGEQRF <http://www.netlib.org/lapack/lapack-3.1.1/html/dgeqrf.f.html>`_
and `DORGQR <http://www.netlib.org/lapack/lapack-3.1.1/html/dorgqr.f.html>`_.
Expand Down Expand Up @@ -423,8 +423,8 @@ MMT : :class:`~openturns.Matrix`

Notes
-----
When transposed is set to `True`, the method computes :math:`cM^t \times \cM`.
Otherwise it computes :math:`\cM \ times \cM^t`
When `transposed` is `True`, compute :math:`\Tr{M} M`.
Otherwise, compute :math:`M \Tr{M}`.

Examples
--------
Expand Down Expand Up @@ -475,12 +475,15 @@ C : :class:`~openturns.Matrix`
Notes
-----

The matrix :math:`\cC` resulting from the Hadamard product of the matrices
:math:`\cA` and :`\cB` is as follows:
The matrix :math:`\mat{C} \in \Rset^{m \times n}` resulting from the Hadamard product (
also known as the elementwise product) of the matrices
:math:`\mat{A} \in \Rset^{m \times n}` and :math:`\mat{B} \in \Rset^{m \times n}` is:

.. math::

\cC_{i,j} = \cA_{i,j} * \cB_{i,j}
c_{i,j} = a_{i,j} b_{i,j}

for any :math:`i = 1, \cdots, m` and :math:`j = 1, \cdots, n`.

Examples
--------
Expand Down Expand Up @@ -508,11 +511,11 @@ sum : a float
Notes
-----

We compute here the sum of elements of the matrix :math:`\cM`, that defines as:
Compute the sum of elements of the matrix :math:`\mat{A} \in \Rset^{m \times n}`:

.. math::

s = \sum_{i=1}^{nRows}\sum_{j=1}^{nColumns} \cM_{i,j}
s = \sum_{i=1}^{m} \sum_{j=1}^{n} a_{i,j}

Examples
--------
Expand Down
21 changes: 20 additions & 1 deletion python/src/SquareMatrix_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Get or set terms
[[ 1 ]
[ 1 ]]

Create an openturns matrix from a **square** numpy 2d-array (or matrix, or
Create a matrix from a **square** Numpy 2d-array (or matrix, or
2d-list)...

>>> import numpy as np
Expand Down Expand Up @@ -308,3 +308,22 @@ Examples
>>> M = ot.SquareMatrix([[1.0, 2.0], [3.0, 4.0]])
>>> M.computeTrace()
5.0"

// ---------------------------------------------------------------------

%feature("docstring") OT::SquareMatrix::inverse
"Compute the inverse of the matrix.

Returns
-------
inverseMatrix : :class:`~openturns.SquareMatrix`
The inverse of the matrix.

Examples
--------
>>> import openturns as ot
>>> M = ot.SquareMatrix([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [2.0, 1.0, 3.0]])
>>> print(12.0 * M.inverse())
[[ -5 3 4 ]
[ 7 3 -8 ]
[ 1 -3 4 ]]"
21 changes: 20 additions & 1 deletion python/src/SymmetricMatrix_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Get or set terms
[[ 1 ]
[ 2 ]]

Create an openturns matrix from a **symmetric** numpy 2d-array (or matrix, or
Create a matrix from a **symmetric** Numpy 2d-array (or matrix, or
2d-list)...

>>> import numpy as np
Expand Down Expand Up @@ -158,3 +158,22 @@ Examples

%feature("docstring") OT::SymmetricMatrix::checkSymmetry
"Check if the internal representation is really symmetric."

// ---------------------------------------------------------------------

%feature("docstring") OT::SymmetricMatrix::inverse
"Compute the inverse of the matrix.

Returns
-------
inverseMatrix : :class:`~openturns.SymmetricMatrix`
The inverse of the matrix.

Examples
--------
>>> import openturns as ot
>>> M = ot.SymmetricMatrix([[4.0, 2.0, 1.0], [2.0, 5.0, 3.0], [1.0, 3.0, 6.0]])
>>> print(67.0 * M.inverse())
[[ 21 -9 1 ]
[ -9 23 -10 ]
[ 1 -10 16 ]]"
10 changes: 10 additions & 0 deletions python/test/t_SquareMatrix_std.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#! /usr/bin/env python

import openturns as ot
from openturns.testing import assert_almost_equal

ot.TESTPREAMBLE()

Expand Down Expand Up @@ -126,3 +127,12 @@
print("squareMatrix0 is empty = ", squareMatrix0.isEmpty())
print("squareMatrix1 is empty = ", squareMatrix1.isEmpty())
print("squareMatrix5 is empty = ", squareMatrix5.isEmpty())

# Check inverse()
squareMatrix6 = ot.SquareMatrix([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [2.0, 1.0, 3.0]])
squareMatrix7 = squareMatrix6.inverse()
inverseReference = ot.SquareMatrix(
[[-5.0, 3.0, 4.0], [7.0, 3.0, -8.0], [1.0, -3.0, 4.0]]
)
inverseReference /= 12.0
assert_almost_equal(squareMatrix7, inverseReference)
8 changes: 8 additions & 0 deletions python/test/t_SymmetricMatrix_std.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#! /usr/bin/env python

import openturns as ot
from openturns.testing import assert_almost_equal

ot.TESTPREAMBLE()

Expand Down Expand Up @@ -127,3 +128,10 @@
print("symmetricMatrix0 is empty = ", symmetricMatrix0.isEmpty())
print("symmetricMatrix1 is empty = ", symmetricMatrix1.isEmpty())
print("symmetricMatrix5 is empty = ", symmetricMatrix5.isEmpty())

# Check inverse()
symmetricMatrix6 = ot.SymmetricMatrix([[4.0, 2.0, 1.0], [2.0, 5.0, 3.0], [1.0, 3.0, 6.0]])
symmetricMatrix7 = symmetricMatrix6.inverse()
inverseReference = ot.SymmetricMatrix([[21.0, -9.0, 1.0], [-9.0, 23.0, -10.0], [1.0, -10.0, 16.0]])
inverseReference /= 67.0
assert_almost_equal(symmetricMatrix7, inverseReference)

0 comments on commit ee5b184

Please sign in to comment.