diff --git a/include/amici/newton_solver.h b/include/amici/newton_solver.h index 2e8b2f6573..0b3ce7e630 100644 --- a/include/amici/newton_solver.h +++ b/include/amici/newton_solver.h @@ -2,11 +2,8 @@ #define amici_newton_solver_h #include "amici/solver.h" -#include "amici/sundials_matrix_wrapper.h" #include "amici/vector.h" -#include - namespace amici { class Model; @@ -27,18 +24,13 @@ class NewtonSolver { * model * * @param model pointer to the model object + * @param linsol_type type of linear solver to use */ - explicit NewtonSolver(Model const& model); + explicit NewtonSolver(Model const& model, LinearSolver linsol_type); - /** - * @brief Factory method to create a NewtonSolver based on linsolType - * - * @param simulationSolver solver with settings - * @param model pointer to the model instance - * @return solver NewtonSolver according to the specified linsolType - */ - static std::unique_ptr - getSolver(Solver const& simulationSolver, Model const& model); + NewtonSolver(NewtonSolver const&) = delete; + + NewtonSolver& operator=(NewtonSolver const& other) = delete; /** * @brief Computes the solution of one Newton iteration @@ -68,8 +60,7 @@ class NewtonSolver { * @param model pointer to the model instance * @param state current simulation state */ - virtual void prepareLinearSystem(Model& model, SimulationState const& state) - = 0; + void prepareLinearSystem(Model& model, SimulationState const& state); /** * Writes the Jacobian (JB) for the Newton iteration and passes it to the @@ -78,9 +69,7 @@ class NewtonSolver { * @param model pointer to the model instance * @param state current simulation state */ - virtual void - prepareLinearSystemB(Model& model, SimulationState const& state) - = 0; + void prepareLinearSystemB(Model& model, SimulationState const& state); /** * @brief Solves the linear system for the Newton step @@ -88,28 +77,23 @@ class NewtonSolver { * @param rhs containing the RHS of the linear system, will be * overwritten by solution to the linear system */ - virtual void solveLinearSystem(AmiVector& rhs) = 0; + void solveLinearSystem(AmiVector& rhs); /** * @brief Reinitialize the linear solver * */ - virtual void reinitialize() = 0; + void reinitialize(); /** - * @brief Checks whether linear system is singular + * @brief Checks whether the linear system is singular * - * @param model pointer to the model instance - * @param state current simulation state * @return boolean indicating whether the linear system is singular * (condition number < 1/machine precision) */ - virtual bool is_singular(Model& model, SimulationState const& state) const - = 0; - - virtual ~NewtonSolver() = default; + bool is_singular(Model& model, SimulationState const& state) const; - protected: + private: /** dummy rhs, used as dummy argument when computing J and JB */ AmiVector xdot_; /** dummy state, attached to linear solver */ @@ -119,88 +103,9 @@ class NewtonSolver { /** dummy differential adjoint state, used as dummy argument when computing * JB */ AmiVector dxB_; -}; - -/** - * @brief The NewtonSolverDense provides access to the dense linear solver for - * the Newton method. - */ - -class NewtonSolverDense : public NewtonSolver { - - public: - /** - * @brief constructor for sparse solver - * - * @param model model instance that provides problem dimensions - */ - explicit NewtonSolverDense(Model const& model); - - NewtonSolverDense(NewtonSolverDense const&) = delete; - - NewtonSolverDense& operator=(NewtonSolverDense const& other) = delete; - - ~NewtonSolverDense() override; - - void solveLinearSystem(AmiVector& rhs) override; - - void - prepareLinearSystem(Model& model, SimulationState const& state) override; - - void - prepareLinearSystemB(Model& model, SimulationState const& state) override; - - void reinitialize() override; - - bool is_singular(Model& model, SimulationState const& state) const override; - - private: - /** temporary storage of Jacobian */ - SUNMatrixWrapper Jtmp_; - - /** dense linear solver */ - SUNLinearSolver linsol_{nullptr}; -}; - -/** - * @brief The NewtonSolverSparse provides access to the sparse linear solver for - * the Newton method. - */ - -class NewtonSolverSparse : public NewtonSolver { - - public: - /** - * @brief constructor for dense solver - * - * @param model model instance that provides problem dimensions - */ - explicit NewtonSolverSparse(Model const& model); - - NewtonSolverSparse(NewtonSolverSparse const&) = delete; - - NewtonSolverSparse& operator=(NewtonSolverSparse const& other) = delete; - - ~NewtonSolverSparse() override; - - void solveLinearSystem(AmiVector& rhs) override; - - void - prepareLinearSystem(Model& model, SimulationState const& state) override; - - void - prepareLinearSystemB(Model& model, SimulationState const& state) override; - - bool is_singular(Model& model, SimulationState const& state) const override; - - void reinitialize() override; - - private: - /** temporary storage of Jacobian */ - SUNMatrixWrapper Jtmp_; - /** sparse linear solver */ - SUNLinearSolver linsol_{nullptr}; + /** linear solver */ + std::unique_ptr linsol_; }; } // namespace amici diff --git a/include/amici/steadystateproblem.h b/include/amici/steadystateproblem.h index da6af249ac..f680b6c128 100644 --- a/include/amici/steadystateproblem.h +++ b/include/amici/steadystateproblem.h @@ -452,7 +452,7 @@ class SteadystateProblem { realtype rtol_quad_{NAN}; /** newton solver */ - std::unique_ptr newton_solver_{nullptr}; + NewtonSolver newton_solver_; /** damping factor flag */ NewtonDampingFactorMode damping_factor_mode_{NewtonDampingFactorMode::on}; diff --git a/include/amici/sundials_linsol_wrapper.h b/include/amici/sundials_linsol_wrapper.h index d8fa1e72b8..d9d2105c38 100644 --- a/include/amici/sundials_linsol_wrapper.h +++ b/include/amici/sundials_linsol_wrapper.h @@ -195,7 +195,8 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param ordering */ SUNLinSolKLU( - AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering = StateOrdering::COLAMD + AmiVector const& x, int nnz, int sparsetype, + StateOrdering ordering = StateOrdering::COLAMD ); /** @@ -214,6 +215,14 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param ordering */ void setOrdering(StateOrdering ordering); + + /** + * @brief Checks whether the linear system is singular + * + * @return boolean indicating whether the linear system is singular + * (condition number < 1/machine precision) + */ + bool is_singular() const; }; #ifdef SUNDIALS_SUPERLUMT diff --git a/src/newton_solver.cpp b/src/newton_solver.cpp index 72bf0a3d64..b37e3d0b81 100644 --- a/src/newton_solver.cpp +++ b/src/newton_solver.cpp @@ -10,56 +10,24 @@ namespace amici { -NewtonSolver::NewtonSolver(Model const& model) +NewtonSolver::NewtonSolver(Model const& model, LinearSolver linsol_type) : xdot_(model.nx_solver) , x_(model.nx_solver) , xB_(model.nJ * model.nx_solver) - , dxB_(model.nJ * model.nx_solver) {} + , dxB_(model.nJ * model.nx_solver) { -std::unique_ptr -NewtonSolver::getSolver(Solver const& simulationSolver, Model const& model) { - - std::unique_ptr solver; - - switch (simulationSolver.getLinearSolver()) { - - /* DIRECT SOLVERS */ + switch (linsol_type) { case LinearSolver::dense: - solver.reset(new NewtonSolverDense(model)); + linsol_.reset(new SUNLinSolDense(x_)); break; - - case LinearSolver::band: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - case LinearSolver::LAPACKDense: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - case LinearSolver::LAPACKBand: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - case LinearSolver::diag: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - /* ITERATIVE SOLVERS */ - case LinearSolver::SPGMR: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - case LinearSolver::SPBCG: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - case LinearSolver::SPTFQMR: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); - - /* SPARSE SOLVERS */ - case LinearSolver::SuperLUMT: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); case LinearSolver::KLU: - solver.reset(new NewtonSolverSparse(model)); + linsol_.reset(new SUNLinSolKLU(x_, model.nnz, CSC_MAT)); break; default: - throw NewtonFailure(AMICI_NOT_IMPLEMENTED, "getSolver"); + throw NewtonFailure( + AMICI_NOT_IMPLEMENTED, "Unknown linear solver type" + ); } - return solver; } void NewtonSolver::getStep( @@ -105,144 +73,59 @@ void NewtonSolver::computeNewtonSensis( } } -NewtonSolverDense::NewtonSolverDense(Model const& model) - : NewtonSolver(model) - , Jtmp_(model.nx_solver, model.nx_solver) - , linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_)) { - auto status = SUNLinSolInitialize_Dense(linsol_); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolInitialize_Dense"); -} - -void NewtonSolverDense::prepareLinearSystem( +void NewtonSolver::prepareLinearSystem( Model& model, SimulationState const& state ) { - model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_); - Jtmp_.refresh(); - auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolSetup_Dense"); + auto& J = linsol_->getMatrix(); + if (J.matrix_id() == SUNMATRIX_SPARSE) { + model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, J); + } else { + model.fJ(state.t, 0.0, state.x, state.dx, xdot_, J); + } + J.refresh(); + linsol_->setup(); } -void NewtonSolverDense::prepareLinearSystemB( +void NewtonSolver::prepareLinearSystemB( Model& model, SimulationState const& state ) { - model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_); - Jtmp_.refresh(); - auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolSetup_Dense"); + auto& J = linsol_->getMatrix(); + if (J.matrix_id() == SUNMATRIX_SPARSE) { + model.fJSparseB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, J); + } else { + model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, J); + } + J.refresh(); + linsol_->setup(); } -void NewtonSolverDense::solveLinearSystem(AmiVector& rhs) { - auto status = SUNLinSolSolve_Dense( - linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0 - ); - Jtmp_.refresh(); +void NewtonSolver::solveLinearSystem(AmiVector& rhs) { // last argument is tolerance and does not have any influence on result - + auto status = linsol_->solve(rhs.getNVector(), rhs.getNVector(), 0.0); if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolSolve_Dense"); + throw NewtonFailure(status, "SUNLinSolSolve"); } -void NewtonSolverDense::reinitialize() { - /* dense solver does not need reinitialization */ +void NewtonSolver::reinitialize() { + // dense solver does not need reinitialization + + if (auto s = dynamic_cast(linsol_.get())) { + s->reInit(s->getMatrix().capacity(), SUNKLU_REINIT_PARTIAL); + } }; -bool NewtonSolverDense::is_singular(Model& model, SimulationState const& state) +bool NewtonSolver::is_singular(Model& model, SimulationState const& state) const { + if (auto s = dynamic_cast(linsol_.get())) { + return s->is_singular(); + } + // dense solver doesn't have any implementation for rcond/condest, so use // sparse solver interface, not the most efficient solution, but who is // concerned about speed and used the dense solver anyways ¯\_(ツ)_/¯ - NewtonSolverSparse sparse_solver(model); + NewtonSolver sparse_solver(model, LinearSolver::KLU); sparse_solver.prepareLinearSystem(model, state); return sparse_solver.is_singular(model, state); } -NewtonSolverDense::~NewtonSolverDense() { - if (linsol_) - SUNLinSolFree_Dense(linsol_); -} - -NewtonSolverSparse::NewtonSolverSparse(Model const& model) - : NewtonSolver(model) - , Jtmp_(model.nx_solver, model.nx_solver, model.nnz, CSC_MAT) - , linsol_(SUNKLU(x_.getNVector(), Jtmp_)) { - auto status = SUNLinSolInitialize_KLU(linsol_); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolInitialize_KLU"); -} - -void NewtonSolverSparse::prepareLinearSystem( - Model& model, SimulationState const& state -) { - /* Get sparse Jacobian */ - model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_); - Jtmp_.refresh(); - auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolSetup_KLU"); -} - -void NewtonSolverSparse::prepareLinearSystemB( - Model& model, SimulationState const& state -) { - /* Get sparse Jacobian */ - model.fJSparseB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_); - Jtmp_.refresh(); - auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolSetup_KLU"); -} - -void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) { - /* Pass pointer to the linear solver */ - auto status = SUNLinSolSolve_KLU( - linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0 - ); - // last argument is tolerance and does not have any influence on result - - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSolSolve_KLU"); -} - -void NewtonSolverSparse::reinitialize() { - /* partial reinitialization, don't need to reallocate Jtmp_ */ - auto status = SUNLinSol_KLUReInit( - linsol_, Jtmp_, Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL - ); - if (status != SUNLS_SUCCESS) - throw NewtonFailure(status, "SUNLinSol_KLUReInit"); -} - -bool NewtonSolverSparse:: - is_singular(Model& /*model*/, SimulationState const& /*state*/) const { - // adapted from SUNLinSolSetup_KLU in sunlinsol/klu/sunlinsol_klu.c - auto content = (SUNLinearSolverContent_KLU)(linsol_->content); - // first cheap check via rcond - auto status - = sun_klu_rcond(content->symbolic, content->numeric, &content->common); - if (status == 0) - throw NewtonFailure(content->last_flag, "sun_klu_rcond"); - - auto precision = std::numeric_limits::epsilon(); - - if (content->common.rcond < precision) { - // cheap check indicates singular, expensive check via condest - status = sun_klu_condest( - SM_INDEXPTRS_S(Jtmp_.get()), SM_DATA_S(Jtmp_.get()), - content->symbolic, content->numeric, &content->common - ); - if (status == 0) - throw NewtonFailure(content->last_flag, "sun_klu_rcond"); - return content->common.condest > 1.0 / precision; - } - return false; -} - -NewtonSolverSparse::~NewtonSolverSparse() { - if (linsol_) - SUNLinSolFree_KLU(linsol_); -} - } // namespace amici diff --git a/src/steadystateproblem.cpp b/src/steadystateproblem.cpp index 5feb319d9a..663d83beba 100644 --- a/src/steadystateproblem.cpp +++ b/src/steadystateproblem.cpp @@ -45,7 +45,7 @@ SteadystateProblem::SteadystateProblem(Solver const& solver, Model const& model) , rtol_sensi_(solver.getRelativeToleranceSteadyStateSensi()) , atol_quad_(solver.getAbsoluteToleranceQuadratures()) , rtol_quad_(solver.getRelativeToleranceQuadratures()) - , newton_solver_(NewtonSolver::getSolver(solver, model)) + , newton_solver_(NewtonSolver(model, solver.getLinearSolver())) , damping_factor_mode_(solver.getNewtonDampingFactorMode()) , damping_factor_lower_bound_(solver.getNewtonDampingFactorLowerBound()) , newton_step_conv_(solver.getNewtonStepSteadyStateCheck()) @@ -87,7 +87,7 @@ void SteadystateProblem::workSteadyStateProblem( try { /* this might still fail, if the Jacobian is singular and simulation did not find a steady state */ - newton_solver_->computeNewtonSensis(state_.sx, model, state_); + newton_solver_.computeNewtonSensis(state_.sx, model, state_); } catch (NewtonFailure const&) { throw AmiException( "Steady state sensitivity computation failed due " @@ -250,7 +250,7 @@ void SteadystateProblem::findSteadyStateBySimulation( void SteadystateProblem::initializeForwardProblem( int it, Solver const& solver, Model& model ) { - newton_solver_->reinitialize(); + newton_solver_.reinitialize(); /* process solver handling for pre- or postequilibration */ if (it == -1) { /* solver was not run before, set up everything */ @@ -279,7 +279,7 @@ void SteadystateProblem::initializeForwardProblem( bool SteadystateProblem::initializeBackwardProblem( Solver const& solver, Model& model, BackwardProblem const* bwd ) { - newton_solver_->reinitialize(); + newton_solver_.reinitialize(); /* note that state_ is still set from forward run */ if (bwd) { /* preequilibration */ @@ -359,8 +359,8 @@ void SteadystateProblem::getQuadratureByLinSolve(Model& model) { /* try to solve the linear system */ try { /* compute integral over xB and write to xQ */ - newton_solver_->prepareLinearSystemB(model, state_); - newton_solver_->solveLinearSystem(xQ_); + newton_solver_.prepareLinearSystemB(model, state_); + newton_solver_.solveLinearSystem(xQ_); /* Compute the quadrature as the inner product xQ * dxdotdp */ computeQBfromQ(model, xQ_, xQB_); /* set flag that quadratures is available (for processing in rdata) */ @@ -583,7 +583,7 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) { state_.t, state_.x, state_.dx, ip, state_.sx[ip], state_.dx, xdot_ ); if (newton_step_conv_) - newton_solver_->solveLinearSystem(xdot_); + newton_solver_.solveLinearSystem(xdot_); wrms = getWrmsNorm( state_.sx[ip], xdot_, steadystate_mask_, atol_sensi_, rtol_sensi_, ewt_ @@ -887,7 +887,7 @@ void SteadystateProblem::getNewtonStep(Model& model) { return; updateRightHandSide(model); delta_.copy(xdot_); - newton_solver_->getStep(delta_, model, state_); + newton_solver_.getStep(delta_, model, state_); delta_updated_ = true; } } // namespace amici diff --git a/src/sundials_linsol_wrapper.cpp b/src/sundials_linsol_wrapper.cpp index 7808620227..62bd5b1808 100644 --- a/src/sundials_linsol_wrapper.cpp +++ b/src/sundials_linsol_wrapper.cpp @@ -25,7 +25,8 @@ SUNLinSolWrapper::SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept { std::swap(A_, other.A_); } -SUNLinSolWrapper& SUNLinSolWrapper::operator=(SUNLinSolWrapper&& other) noexcept { +SUNLinSolWrapper& SUNLinSolWrapper::operator=(SUNLinSolWrapper&& other +) noexcept { std::swap(solver_, other.solver_); std::swap(A_, other.A_); return *this; @@ -215,6 +216,30 @@ void SUNLinSolKLU::setOrdering(StateOrdering ordering) { throw AmiException("SUNLinSol_KLUSetOrdering failed with %d", status); } +bool SUNLinSolKLU::is_singular() const { + // adapted from SUNLinSolSetup_KLU in sunlinsol/klu/sunlinsol_klu.c + auto content = (SUNLinearSolverContent_KLU)(solver_->content); + // first cheap check via rcond + auto status + = sun_klu_rcond(content->symbolic, content->numeric, &content->common); + if (status == 0) + throw AmiException("sun_klu_rcond: %d", content->last_flag); + + auto precision = std::numeric_limits::epsilon(); + + if (content->common.rcond < precision) { + // cheap check indicates singular, expensive check via condest + status = sun_klu_condest( + SM_INDEXPTRS_S(A_.get()), SM_DATA_S(A_.get()), content->symbolic, + content->numeric, &content->common + ); + if (status == 0) + throw AmiException("sun_klu_condest: %d", content->last_flag); + return content->common.condest > 1.0 / precision; + } + return false; +} + SUNLinSolPCG::SUNLinSolPCG(N_Vector y, int pretype, int maxl) : SUNLinSolWrapper(SUNLinSol_PCG(y, pretype, maxl)) { if (!solver_)