diff --git a/include/amici/sundials_linsol_wrapper.h b/include/amici/sundials_linsol_wrapper.h index 9009115489..bc280039cd 100644 --- a/include/amici/sundials_linsol_wrapper.h +++ b/include/amici/sundials_linsol_wrapper.h @@ -33,10 +33,21 @@ class SUNLinSolWrapper { /** * @brief Wrap existing SUNLinearSolver - * @param linsol + * + * @param linsol SUNLinSolWrapper takes ownership of `linsol`. */ explicit SUNLinSolWrapper(SUNLinearSolver linsol); + /** + * @brief Wrap existing SUNLinearSolver + * + * @param linsol SUNLinSolWrapper takes ownership of `linsol`. + * @param A Matrix + */ + explicit SUNLinSolWrapper( + SUNLinearSolver linsol, SUNMatrixWrapper const& A + ); + virtual ~SUNLinSolWrapper(); /** @@ -80,26 +91,17 @@ class SUNLinSolWrapper { /** * @brief Performs any linear solver setup needed, based on an updated * system matrix A. - * @param A */ - void setup(SUNMatrix A) const; - - /** - * @brief Performs any linear solver setup needed, based on an updated - * system matrix A. - * @param A - */ - void setup(SUNMatrixWrapper const& A) const; + void setup() const; /** * @brief Solves a linear system A*x = b - * @param A * @param x A template for cloning vectors needed within the solver. * @param b * @param tol Tolerance (weighted 2-norm), iterative solvers only * @return error flag */ - int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const; + int solve(N_Vector x, N_Vector b, realtype tol) const; /** * @brief Returns the last error flag encountered within the linear solver @@ -119,7 +121,7 @@ class SUNLinSolWrapper { * @brief Get the matrix A (matrix solvers only). * @return A */ - virtual SUNMatrix getMatrix() const; + virtual SUNMatrixWrapper& getMatrix(); protected: /** @@ -131,6 +133,9 @@ class SUNLinSolWrapper { /** Wrapped solver */ SUNLinearSolver solver_{nullptr}; + + /** Matrix A for solver. */ + SUNMatrixWrapper A_; }; /** @@ -139,12 +144,12 @@ class SUNLinSolWrapper { class SUNLinSolBand : public SUNLinSolWrapper { public: /** - * @brief Create solver using existing matrix A without taking ownership of - * A. + * @brief Create solver using existing matrix A + * * @param x A template for cloning vectors needed within the solver. * @param A square matrix */ - SUNLinSolBand(N_Vector x, SUNMatrix A); + SUNLinSolBand(N_Vector x, SUNMatrixWrapper A); /** * @brief Create new band solver and matrix A. @@ -153,12 +158,6 @@ class SUNLinSolBand : public SUNLinSolWrapper { * @param lbw lower bandwidth of band matrix A */ SUNLinSolBand(AmiVector const& x, int ubw, int lbw); - - SUNMatrix getMatrix() const override; - - private: - /** Matrix A for solver, only if created by here. */ - SUNMatrixWrapper A_; }; /** @@ -171,12 +170,6 @@ class SUNLinSolDense : public SUNLinSolWrapper { * @param x A template for cloning vectors needed within the solver. */ explicit SUNLinSolDense(AmiVector const& x); - - SUNMatrix getMatrix() const override; - - private: - /** Matrix A for solver, only if created by here. */ - SUNMatrixWrapper A_; }; /** @@ -192,7 +185,7 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param x A template for cloning vectors needed within the solver. * @param A sparse matrix */ - SUNLinSolKLU(N_Vector x, SUNMatrix A); + SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A); /** * @brief Create KLU solver and matrix to operate on @@ -202,11 +195,10 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param ordering */ SUNLinSolKLU( - AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering + AmiVector const& x, int nnz, int sparsetype, + StateOrdering ordering = StateOrdering::COLAMD ); - SUNMatrix getMatrix() const override; - /** * @brief Reinitializes memory and flags for a new factorization * (symbolic and numeric) to be conducted at the next solver setup call. @@ -223,10 +215,6 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param ordering */ void setOrdering(StateOrdering ordering); - - private: - /** Sparse matrix A for solver, only if created by here. */ - SUNMatrixWrapper A_; }; #ifdef SUNDIALS_SUPERLUMT @@ -249,7 +237,7 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper { * @param A sparse matrix * @param numThreads Number of threads to be used by SuperLUMT */ - SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads); + SUNLinSolSuperLUMT(N_Vector x, SUNMatrixWrapper A, int numThreads); /** * @brief Create SuperLUMT solver and matrix to operate on @@ -279,18 +267,12 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper { int numThreads ); - SUNMatrix getMatrix() const override; - /** * @brief Sets the ordering used by SuperLUMT for reducing fill in the * linear solve. * @param ordering */ void setOrdering(StateOrdering ordering); - - private: - /** Sparse matrix A for solver, only if created by here. */ - SUNMatrixWrapper A; }; #endif diff --git a/include/amici/vector.h b/include/amici/vector.h index da4682e654..72da7a9230 100644 --- a/include/amici/vector.h +++ b/include/amici/vector.h @@ -415,7 +415,6 @@ class AmiVectorArray { void copy(AmiVectorArray const& other); private: - /** main data storage */ std::vector vec_array_; diff --git a/src/sundials_linsol_wrapper.cpp b/src/sundials_linsol_wrapper.cpp index 9118c4bd63..d1b01c8546 100644 --- a/src/sundials_linsol_wrapper.cpp +++ b/src/sundials_linsol_wrapper.cpp @@ -9,6 +9,12 @@ namespace amici { SUNLinSolWrapper::SUNLinSolWrapper(SUNLinearSolver linsol) : solver_(linsol) {} +SUNLinSolWrapper::SUNLinSolWrapper( + SUNLinearSolver linsol, SUNMatrixWrapper const& A +) + : solver_(linsol) + , A_(A) {} + SUNLinSolWrapper::~SUNLinSolWrapper() { if (solver_) SUNLinSolFree(solver_); @@ -16,6 +22,14 @@ SUNLinSolWrapper::~SUNLinSolWrapper() { SUNLinSolWrapper::SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept { std::swap(solver_, other.solver_); + std::swap(A_, other.A_); +} + +SUNLinSolWrapper& SUNLinSolWrapper::operator=(SUNLinSolWrapper&& other +) noexcept { + std::swap(solver_, other.solver_); + std::swap(A_, other.A_); + return *this; } SUNLinearSolver SUNLinSolWrapper::get() const { return solver_; } @@ -31,19 +45,14 @@ int SUNLinSolWrapper::initialize() { return res; } -void SUNLinSolWrapper::setup(SUNMatrix A) const { - auto res = SUNLinSolSetup(solver_, A); +void SUNLinSolWrapper::setup() const { + auto res = SUNLinSolSetup(solver_, A_.get()); if (res != SUN_SUCCESS) throw AmiException("Solver setup failed with code %d", res); } -void SUNLinSolWrapper::setup(SUNMatrixWrapper const& A) const { - return setup(A.get()); -} - -int SUNLinSolWrapper::Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) - const { - return SUNLinSolSolve(solver_, A, x, b, tol); +int SUNLinSolWrapper::solve(N_Vector x, N_Vector b, realtype tol) const { + return SUNLinSolSolve(solver_, A_.get(), x, b, tol); } long SUNLinSolWrapper::getLastFlag() const { @@ -54,7 +63,7 @@ int SUNLinSolWrapper::space(long* lenrwLS, long* leniwLS) const { return SUNLinSolSpace(solver_, lenrwLS, leniwLS); } -SUNMatrix SUNLinSolWrapper::getMatrix() const { return nullptr; } +SUNMatrixWrapper& SUNLinSolWrapper::getMatrix() { return A_; } SUNNonLinSolWrapper::SUNNonLinSolWrapper(SUNNonlinearSolver sol) : solver(sol) {} @@ -153,24 +162,26 @@ void SUNNonLinSolWrapper::initialize() { ); } -SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A) +SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrixWrapper A) : SUNLinSolWrapper(SUNLinSol_Band(x, A, x->sunctx)) { if (!solver_) throw AmiException("Failed to create solver."); } SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw) - : A_(SUNMatrixWrapper(x.getLength(), ubw, lbw, x.get_ctx())) { + : SUNLinSolWrapper( + nullptr, SUNMatrixWrapper(x.getLength(), ubw, lbw, x.get_ctx()) + ) { solver_ = SUNLinSol_Band(const_cast(x.getNVector()), A_, x.get_ctx()); if (!solver_) throw AmiException("Failed to create solver."); } -SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); } - SUNLinSolDense::SUNLinSolDense(AmiVector const& x) - : A_(SUNMatrixWrapper(x.getLength(), x.getLength(), x.get_ctx())) { + : SUNLinSolWrapper( + nullptr, SUNMatrixWrapper(x.getLength(), x.getLength(), x.get_ctx()) + ) { solver_ = SUNLinSol_Dense( const_cast(x.getNVector()), A_, x.get_ctx() ); @@ -178,9 +189,7 @@ SUNLinSolDense::SUNLinSolDense(AmiVector const& x) throw AmiException("Failed to create solver."); } -SUNMatrix SUNLinSolDense::getMatrix() const { return A_.get(); } - -SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A) +SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A) : SUNLinSolWrapper(SUNLinSol_KLU(x, A, x->sunctx)) { if (!solver_) throw AmiException("Failed to create solver."); @@ -189,20 +198,20 @@ SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A) SUNLinSolKLU::SUNLinSolKLU( AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering ) - : A_(SUNMatrixWrapper( - x.getLength(), x.getLength(), nnz, sparsetype, x.get_ctx() - )) { - solver_ = SUNLinSol_KLU( - const_cast(x.getNVector()), A_, A_.get()->sunctx - ); + : SUNLinSolWrapper( + nullptr, + SUNMatrixWrapper( + x.getLength(), x.getLength(), nnz, sparsetype, x.get_ctx() + ) + ) { + solver_ + = SUNLinSol_KLU(const_cast(x.getNVector()), A_, x.get_ctx()); if (!solver_) throw AmiException("Failed to create solver."); setOrdering(ordering); } -SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); } - void SUNLinSolKLU::reInit(int nnz, int reinit_type) { int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type); if (status != SUN_SUCCESS) @@ -422,8 +431,10 @@ int SUNNonLinSolFixedPoint::getSysFn(SUNNonlinSolSysFn* SysFn) const { #ifdef SUNDIALS_SUPERLUMT -SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads) - : SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads)) { +SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( + N_Vector x, SUNMatrixWrapper A, int numThreads +) + : SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads), A) { if (!solver) throw AmiException("Failed to create solver."); } @@ -432,7 +443,10 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( AmiVector const& x, int nnz, int sparsetype, SUNLinSolSuperLUMT::StateOrdering ordering ) - : A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) { + : SUNLinSolWrapper( + nullptr, + SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype) + ) { int numThreads = 1; if (auto env = std::getenv("AMICI_SUPERLUMT_NUM_THREADS")) { numThreads = std::max(1, std::stoi(env)); @@ -449,7 +463,10 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering, int numThreads ) - : A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) { + : SUNLinSolWrapper( + nullptr, + SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype) + ) { solver = SUNLinSol_SuperLUMT(x.getNVector(), A.get(), numThreads); if (!solver) throw AmiException("Failed to create solver."); @@ -457,8 +474,6 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( setOrdering(ordering); } -SUNMatrix SUNLinSolSuperLUMT::getMatrix() const { return A.get(); } - void SUNLinSolSuperLUMT::setOrdering(StateOrdering ordering) { auto status = SUNLinSol_SuperLUMTSetOrdering(solver, static_cast(ordering));