Skip to content

Commit

Permalink
Merge branch 'develop' into update_sundials
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Oct 1, 2024
2 parents f5f462a + be02a6e commit 23c8cd3
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 76 deletions.
68 changes: 25 additions & 43 deletions include/amici/sundials_linsol_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,21 @@ class SUNLinSolWrapper {

/**
* @brief Wrap existing SUNLinearSolver
* @param linsol
*
* @param linsol SUNLinSolWrapper takes ownership of `linsol`.
*/
explicit SUNLinSolWrapper(SUNLinearSolver linsol);

/**
* @brief Wrap existing SUNLinearSolver
*
* @param linsol SUNLinSolWrapper takes ownership of `linsol`.
* @param A Matrix
*/
explicit SUNLinSolWrapper(
SUNLinearSolver linsol, SUNMatrixWrapper const& A
);

virtual ~SUNLinSolWrapper();

/**
Expand Down Expand Up @@ -80,26 +91,17 @@ class SUNLinSolWrapper {
/**
* @brief Performs any linear solver setup needed, based on an updated
* system matrix A.
* @param A
*/
void setup(SUNMatrix A) const;

/**
* @brief Performs any linear solver setup needed, based on an updated
* system matrix A.
* @param A
*/
void setup(SUNMatrixWrapper const& A) const;
void setup() const;

/**
* @brief Solves a linear system A*x = b
* @param A
* @param x A template for cloning vectors needed within the solver.
* @param b
* @param tol Tolerance (weighted 2-norm), iterative solvers only
* @return error flag
*/
int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const;
int solve(N_Vector x, N_Vector b, realtype tol) const;

/**
* @brief Returns the last error flag encountered within the linear solver
Expand All @@ -119,7 +121,7 @@ class SUNLinSolWrapper {
* @brief Get the matrix A (matrix solvers only).
* @return A
*/
virtual SUNMatrix getMatrix() const;
virtual SUNMatrixWrapper& getMatrix();

protected:
/**
Expand All @@ -131,6 +133,9 @@ class SUNLinSolWrapper {

/** Wrapped solver */
SUNLinearSolver solver_{nullptr};

/** Matrix A for solver. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -139,12 +144,12 @@ class SUNLinSolWrapper {
class SUNLinSolBand : public SUNLinSolWrapper {
public:
/**
* @brief Create solver using existing matrix A without taking ownership of
* A.
* @brief Create solver using existing matrix A
*
* @param x A template for cloning vectors needed within the solver.
* @param A square matrix
*/
SUNLinSolBand(N_Vector x, SUNMatrix A);
SUNLinSolBand(N_Vector x, SUNMatrixWrapper A);

/**
* @brief Create new band solver and matrix A.
Expand All @@ -153,12 +158,6 @@ class SUNLinSolBand : public SUNLinSolWrapper {
* @param lbw lower bandwidth of band matrix A
*/
SUNLinSolBand(AmiVector const& x, int ubw, int lbw);

SUNMatrix getMatrix() const override;

private:
/** Matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -171,12 +170,6 @@ class SUNLinSolDense : public SUNLinSolWrapper {
* @param x A template for cloning vectors needed within the solver.
*/
explicit SUNLinSolDense(AmiVector const& x);

SUNMatrix getMatrix() const override;

private:
/** Matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -192,7 +185,7 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param x A template for cloning vectors needed within the solver.
* @param A sparse matrix
*/
SUNLinSolKLU(N_Vector x, SUNMatrix A);
SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A);

/**
* @brief Create KLU solver and matrix to operate on
Expand All @@ -202,11 +195,10 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
AmiVector const& x, int nnz, int sparsetype,
StateOrdering ordering = StateOrdering::COLAMD
);

SUNMatrix getMatrix() const override;

/**
* @brief Reinitializes memory and flags for a new factorization
* (symbolic and numeric) to be conducted at the next solver setup call.
Expand All @@ -223,10 +215,6 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
void setOrdering(StateOrdering ordering);

private:
/** Sparse matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

#ifdef SUNDIALS_SUPERLUMT
Expand All @@ -249,7 +237,7 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
* @param A sparse matrix
* @param numThreads Number of threads to be used by SuperLUMT
*/
SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads);
SUNLinSolSuperLUMT(N_Vector x, SUNMatrixWrapper A, int numThreads);

/**
* @brief Create SuperLUMT solver and matrix to operate on
Expand Down Expand Up @@ -279,18 +267,12 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
int numThreads
);

SUNMatrix getMatrix() const override;

/**
* @brief Sets the ordering used by SuperLUMT for reducing fill in the
* linear solve.
* @param ordering
*/
void setOrdering(StateOrdering ordering);

private:
/** Sparse matrix A for solver, only if created by here. */
SUNMatrixWrapper A;
};

#endif
Expand Down
1 change: 0 additions & 1 deletion include/amici/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ class AmiVectorArray {
void copy(AmiVectorArray const& other);

private:

/** main data storage */
std::vector<AmiVector> vec_array_;

Expand Down
79 changes: 47 additions & 32 deletions src/sundials_linsol_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,27 @@ namespace amici {
SUNLinSolWrapper::SUNLinSolWrapper(SUNLinearSolver linsol)
: solver_(linsol) {}

SUNLinSolWrapper::SUNLinSolWrapper(
SUNLinearSolver linsol, SUNMatrixWrapper const& A
)
: solver_(linsol)
, A_(A) {}

SUNLinSolWrapper::~SUNLinSolWrapper() {
if (solver_)
SUNLinSolFree(solver_);
}

SUNLinSolWrapper::SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept {
std::swap(solver_, other.solver_);
std::swap(A_, other.A_);
}

SUNLinSolWrapper& SUNLinSolWrapper::operator=(SUNLinSolWrapper&& other
) noexcept {
std::swap(solver_, other.solver_);
std::swap(A_, other.A_);
return *this;
}

SUNLinearSolver SUNLinSolWrapper::get() const { return solver_; }
Expand All @@ -31,19 +45,14 @@ int SUNLinSolWrapper::initialize() {
return res;
}

void SUNLinSolWrapper::setup(SUNMatrix A) const {
auto res = SUNLinSolSetup(solver_, A);
void SUNLinSolWrapper::setup() const {
auto res = SUNLinSolSetup(solver_, A_.get());
if (res != SUN_SUCCESS)
throw AmiException("Solver setup failed with code %d", res);
}

void SUNLinSolWrapper::setup(SUNMatrixWrapper const& A) const {
return setup(A.get());
}

int SUNLinSolWrapper::Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
const {
return SUNLinSolSolve(solver_, A, x, b, tol);
int SUNLinSolWrapper::solve(N_Vector x, N_Vector b, realtype tol) const {
return SUNLinSolSolve(solver_, A_.get(), x, b, tol);
}

long SUNLinSolWrapper::getLastFlag() const {
Expand All @@ -54,7 +63,7 @@ int SUNLinSolWrapper::space(long* lenrwLS, long* leniwLS) const {
return SUNLinSolSpace(solver_, lenrwLS, leniwLS);
}

SUNMatrix SUNLinSolWrapper::getMatrix() const { return nullptr; }
SUNMatrixWrapper& SUNLinSolWrapper::getMatrix() { return A_; }

SUNNonLinSolWrapper::SUNNonLinSolWrapper(SUNNonlinearSolver sol)
: solver(sol) {}
Expand Down Expand Up @@ -153,34 +162,34 @@ void SUNNonLinSolWrapper::initialize() {
);
}

SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A)
SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrixWrapper A)
: SUNLinSolWrapper(SUNLinSol_Band(x, A, x->sunctx)) {
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw, x.get_ctx())) {
: SUNLinSolWrapper(
nullptr, SUNMatrixWrapper(x.getLength(), ubw, lbw, x.get_ctx())
) {
solver_
= SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_, x.get_ctx());
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); }

SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), x.get_ctx())) {
: SUNLinSolWrapper(
nullptr, SUNMatrixWrapper(x.getLength(), x.getLength(), x.get_ctx())
) {
solver_ = SUNLinSol_Dense(
const_cast<N_Vector>(x.getNVector()), A_, x.get_ctx()
);
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolDense::getMatrix() const { return A_.get(); }

SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A)
SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A)
: SUNLinSolWrapper(SUNLinSol_KLU(x, A, x->sunctx)) {
if (!solver_)
throw AmiException("Failed to create solver.");
Expand All @@ -189,20 +198,20 @@ SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A)
SUNLinSolKLU::SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
)
: A_(SUNMatrixWrapper(
x.getLength(), x.getLength(), nnz, sparsetype, x.get_ctx()
)) {
solver_ = SUNLinSol_KLU(
const_cast<N_Vector>(x.getNVector()), A_, A_.get()->sunctx
);
: SUNLinSolWrapper(
nullptr,
SUNMatrixWrapper(
x.getLength(), x.getLength(), nnz, sparsetype, x.get_ctx()
)
) {
solver_
= SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_, x.get_ctx());
if (!solver_)
throw AmiException("Failed to create solver.");

setOrdering(ordering);
}

SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); }

void SUNLinSolKLU::reInit(int nnz, int reinit_type) {
int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type);
if (status != SUN_SUCCESS)
Expand Down Expand Up @@ -422,8 +431,10 @@ int SUNNonLinSolFixedPoint::getSysFn(SUNNonlinSolSysFn* SysFn) const {

#ifdef SUNDIALS_SUPERLUMT

SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads)
: SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads)) {
SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
N_Vector x, SUNMatrixWrapper A, int numThreads
)
: SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads), A) {
if (!solver)
throw AmiException("Failed to create solver.");
}
Expand All @@ -432,7 +443,10 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
AmiVector const& x, int nnz, int sparsetype,
SUNLinSolSuperLUMT::StateOrdering ordering
)
: A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(
nullptr,
SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)
) {
int numThreads = 1;
if (auto env = std::getenv("AMICI_SUPERLUMT_NUM_THREADS")) {
numThreads = std::max(1, std::stoi(env));
Expand All @@ -449,16 +463,17 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering,
int numThreads
)
: A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(
nullptr,
SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)
) {
solver = SUNLinSol_SuperLUMT(x.getNVector(), A.get(), numThreads);
if (!solver)
throw AmiException("Failed to create solver.");

setOrdering(ordering);
}

SUNMatrix SUNLinSolSuperLUMT::getMatrix() const { return A.get(); }

void SUNLinSolSuperLUMT::setOrdering(StateOrdering ordering) {
auto status
= SUNLinSol_SuperLUMTSetOrdering(solver, static_cast<int>(ordering));
Expand Down

0 comments on commit 23c8cd3

Please sign in to comment.