Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Sep 28, 2024
1 parent 5980e49 commit b59e24f
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 92 deletions.
2 changes: 2 additions & 0 deletions include/amici/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ constexpr int AMICI_LSETUP_FAIL= -6;
constexpr int AMICI_RHSFUNC_FAIL= -8;
constexpr int AMICI_FIRST_RHSFUNC_ERR= -9;
constexpr int AMICI_CONSTR_FAIL= -15;
constexpr int AMICI_CVODES_CONSTR_FAIL= -15;
constexpr int AMICI_IDAS_CONSTR_FAIL= -11;
constexpr int AMICI_ILL_INPUT= -22;
constexpr int AMICI_ERROR= -99;
constexpr int AMICI_NO_STEADY_STATE= -81;
Expand Down
2 changes: 1 addition & 1 deletion include/amici/model_dae.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Model_DAE : public Model {
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
state_independent_events
) {
auto sunctx = derived_state_.x_pos_tmp_.sunctx_;
SUNContext sunctx = derived_state_.sunctx_;
derived_state_.M_ = SUNMatrixWrapper(nx_solver, nx_solver, sunctx);
auto M_nnz = static_cast<sunindextype>(
std::reduce(idlist.begin(), idlist.end())
Expand Down
51 changes: 34 additions & 17 deletions include/amici/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,31 @@ class AmiVector {
*/
AmiVector() = default;

/** Creates an std::vector<realtype> and attaches the
/**
* @brief Construct empty vector of given size
*
* Creates an std::vector<realtype> and attaches the
* data pointer to a newly created N_Vector_Serial.
* Using N_VMake_Serial ensures that the N_Vector
* module does not try to deallocate the data vector
* when calling N_VDestroy_Serial
* @brief empty constructor
* @param length number of elements in vector
* @param sunctx SUNDIALS context
*/
explicit AmiVector(long int const length, SUNContext sunctx)
: sunctx_(sunctx)
, vec_(static_cast<decltype(vec_)::size_type>(length), 0.0)
: vec_(static_cast<decltype(vec_)::size_type>(length), 0.0)
, nvec_(N_VMake_Serial(length, vec_.data(), sunctx)) {}

/** Moves data from std::vector and constructs an nvec that points to the
/**
* @brief Constructor from std::vector
*
* Moves data from std::vector and constructs an nvec that points to the
* data
* @brief constructor from std::vector,
* @param rvec vector from which the data will be moved
* @param sunctx SUNDIALS context
*/
explicit AmiVector(std::vector<realtype> rvec, SUNContext sunctx)
: sunctx_(sunctx)
, vec_(std::move(rvec))
: vec_(std::move(rvec))
, nvec_(N_VMake_Serial(
gsl::narrow<long int>(vec_.size()), vec_.data(), sunctx
)) {}
Expand All @@ -80,8 +82,7 @@ class AmiVector {
* @param vold vector from which the data will be copied
*/
AmiVector(AmiVector const& vold)
: sunctx_(vold.sunctx_)
, vec_(vold.vec_) {
: vec_(vold.vec_) {
if (vold.nvec_ == nullptr) {
nvec_ = nullptr;
return;
Expand All @@ -97,10 +98,9 @@ class AmiVector {
* @param other vector from which the data will be moved
*/
AmiVector(AmiVector&& other) noexcept
: sunctx_(other.sunctx_)
: vec_(std::move(other.vec_))
, nvec_(nullptr) {
vec_ = std::move(other.vec_);
synchroniseNVector();
synchroniseNVector(other.get_ctx());
}

/**
Expand Down Expand Up @@ -249,8 +249,23 @@ class AmiVector {
Archive& ar, AmiVector& s, unsigned int version
);

/** SUNDIALS context */
SUNContext sunctx_{nullptr};
/**
* @brief Get SUNContext
* @return The current SUNContext or nullptr, if this AmiVector is empty
*/
SUNContext get_ctx() const {
return nvec_ == nullptr ? nullptr : nvec_->sunctx;
}

/**
* @brief Set SUNContext
*
* If this AmiVector is non-empty, changes the current SUNContext of the
* associated N_Vector. If empty, do nothing.
*
* @param ctx SUNDIALS context to set
*/
void set_ctx(SUNContext ctx) { if(nvec_) nvec_->sunctx = ctx; }

private:
/** main data storage */
Expand All @@ -261,8 +276,9 @@ class AmiVector {

/**
* @brief reconstructs nvec such that data pointer points to vec data array
* @param sunctx SUNDIALS context
*/
void synchroniseNVector();
void synchroniseNVector(SUNContext sunctx);
};

/**
Expand Down Expand Up @@ -395,10 +411,11 @@ class AmiVectorArray {
*/
void copy(AmiVectorArray const& other);


private:
/** SUNDIALS context */
SUNContext sunctx_{nullptr};

private:
/** main data storage */
std::vector<AmiVector> vec_array_;

Expand Down
2 changes: 2 additions & 0 deletions src/amici.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ std::map<int, std::string> simulation_status_to_str_map = {
{AMICI_CONV_FAILURE, "AMICI_CONV_FAILURE"},
{AMICI_FIRST_RHSFUNC_ERR, "AMICI_FIRST_RHSFUNC_ERR"},
{AMICI_CONSTR_FAIL, "AMICI_CONSTR_FAIL"},
{AMICI_CVODES_CONSTR_FAIL, "AMICI_CVODES_CONSTR_FAIL"},
{AMICI_IDAS_CONSTR_FAIL, "AMICI_IDAS_CONSTR_FAIL"},
{AMICI_RHSFUNC_FAIL, "AMICI_RHSFUNC_FAIL"},
{AMICI_ILL_INPUT, "AMICI_ILL_INPUT"},
{AMICI_ERROR, "AMICI_ERROR"},
Expand Down
67 changes: 14 additions & 53 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "amici/exception.h"
#include "amici/model.h"
#include "amici/symbolic_functions.h"
#include "amici/amici.h"

#include <sundials/sundials_context.h>

Expand All @@ -12,78 +13,41 @@

namespace amici {

/* Error handler passed to SUNDIALS. */
void wrapErrHandlerFn(
[[maybe_unused]] int line, char const* func, char const* file,
char const* msg, SUNErrCode err_code, void* err_user_data,
[[maybe_unused]] SUNContext sunctx
) {
constexpr int BUF_SIZE = 250;
char buffer[BUF_SIZE];
char buffid[BUF_SIZE];

char msg_buffer[BUF_SIZE];
char id_buffer[BUF_SIZE];
static_assert(
std::is_same<SUNErrCode, int>::value, "Must update format string"
);
// for debug builds, include full file path and line numbers
#ifdef NDEBUG
snprintf(buffer, BUF_SIZE, "%s:%d: %s (%d)", file, line, msg, err_code);
snprintf(msg_buffer, BUF_SIZE, "%s:%d: %s (%d)", file, line, msg, err_code);
#else
snprintf(buffer, BUF_SIZE, "%s", msg);
#endif
// we need a matlab-compatible message ID
// i.e. colon separated and only [A-Za-z0-9_]
// see https://mathworks.com/help/matlab/ref/mexception.html
std::filesystem::path path(file);
auto file_stem = path.stem().string();

switch (err_code) {
case 99:
snprintf(buffid, BUF_SIZE, "%s:%s:WARNING", file_stem.c_str(), func);
break;

case AMICI_TOO_MUCH_WORK:
snprintf(
buffid, BUF_SIZE, "%s:%s:TOO_MUCH_WORK", file_stem.c_str(), func
);
break;

case AMICI_TOO_MUCH_ACC:
snprintf(
buffid, BUF_SIZE, "%s:%s:TOO_MUCH_ACC", file_stem.c_str(), func
);
break;

case AMICI_ERR_FAILURE:
snprintf(
buffid, BUF_SIZE, "%s:%s:ERR_FAILURE", file_stem.c_str(), func
);
break;

case AMICI_CONV_FAILURE:
snprintf(
buffid, BUF_SIZE, "%s:%s:CONV_FAILURE", file_stem.c_str(), func
);
break;

case AMICI_RHSFUNC_FAIL:
snprintf(
buffid, BUF_SIZE, "%s:%s:RHSFUNC_FAIL", file_stem.c_str(), func
);
break;

case AMICI_FIRST_RHSFUNC_ERR:
snprintf(
buffid, BUF_SIZE, "%s:%s:FIRST_RHSFUNC_ERR", file_stem.c_str(), func
);
break;
default:
snprintf(buffid, BUF_SIZE, "%s:%s:OTHER", file_stem.c_str(), func);
break;
}
auto err_code_str = simulation_status_to_str(err_code);
snprintf(id_buffer, BUF_SIZE, "%s:%s:%s", file_stem.c_str(), func, err_code_str.c_str());

if (!err_user_data) {
throw std::runtime_error("eh_data unset");
}

auto solver = static_cast<Solver const*>(err_user_data);
if (solver->logger)
solver->logger->log(LogSeverity::debug, buffid, buffer);
solver->logger->log(LogSeverity::debug, id_buffer, msg_buffer);
}

Solver::Solver(Solver const& other)
Expand Down Expand Up @@ -127,11 +91,8 @@ Solver::Solver(Solver const& other)
, max_step_size_(other.max_step_size_)
, maxstepsB_(other.maxstepsB_)
, sensi_(other.sensi_) {
// AmiVector.setContext()... check for nullptr
if (constraints_.data()) {
constraints_.sunctx_ = sunctx_;
constraints_.getNVector()->sunctx = sunctx_;
}
// update to our own context
constraints_.set_ctx(sunctx_);
}

SUNContext Solver::getSunContext() const { return sunctx_; }
Expand Down
5 changes: 2 additions & 3 deletions src/solver_idas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ static_assert(
amici::AMICI_LSETUP_FAIL == IDA_LSETUP_FAIL,
"AMICI_LSETUP_FAIL != IDA_LSETUP_FAIL"
);
// FIXME: this does not match CVODE, we need separate return values
// static_assert(amici::AMICI_CONSTR_FAIL == IDA_CONSTR_FAIL, "AMICI_CONSTR_FAIL
// != IDA_CONSTR_FAIL");
// This does not match the CVODE code, we need separate return values
static_assert(amici::AMICI_IDAS_CONSTR_FAIL == IDA_CONSTR_FAIL, "AMICI_IDAS_CONSTR_FAIL != IDA_CONSTR_FAIL");

/*
* The following static members are callback function to IDAS.
Expand Down
18 changes: 9 additions & 9 deletions src/sundials_linsol_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,19 @@ SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A)
}

SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw, x.sunctx_)) {
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw, x.get_ctx())) {
solver_
= SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_, x.sunctx_);
= SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_, x.get_ctx());
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); }

SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), x.sunctx_)) {
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), x.get_ctx())) {
solver_
= SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_, x.sunctx_);
= SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_, x.get_ctx());
if (!solver_)
throw AmiException("Failed to create solver.");
}
Expand All @@ -189,7 +189,7 @@ SUNLinSolKLU::SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
)
: A_(SUNMatrixWrapper(
x.getLength(), x.getLength(), nnz, sparsetype, x.sunctx_
x.getLength(), x.getLength(), nnz, sparsetype, x.get_ctx()
)) {
solver_ = SUNLinSol_KLU(
const_cast<N_Vector>(x.getNVector()), A_, A_.get()->sunctx
Expand Down Expand Up @@ -250,7 +250,7 @@ SUNLinSolSPBCGS::SUNLinSolSPBCGS(N_Vector x, int pretype, int maxl)

SUNLinSolSPBCGS::SUNLinSolSPBCGS(AmiVector const& x, int pretype, int maxl) {
solver_ = SUNLinSol_SPBCGS(
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.sunctx_
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.get_ctx()
);
if (!solver_)
throw AmiException("Failed to create solver.");
Expand Down Expand Up @@ -284,7 +284,7 @@ N_Vector SUNLinSolSPBCGS::getResid() const {

SUNLinSolSPFGMR::SUNLinSolSPFGMR(AmiVector const& x, int pretype, int maxl)
: SUNLinSolWrapper(SUNLinSol_SPFGMR(
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.sunctx_
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.get_ctx()
)) {
if (!solver_)
throw AmiException("Failed to create solver.");
Expand Down Expand Up @@ -318,7 +318,7 @@ N_Vector SUNLinSolSPFGMR::getResid() const {

SUNLinSolSPGMR::SUNLinSolSPGMR(AmiVector const& x, int pretype, int maxl)
: SUNLinSolWrapper(SUNLinSol_SPGMR(
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.sunctx_
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.get_ctx()
)) {
if (!solver_)
throw AmiException("Failed to create solver.");
Expand Down Expand Up @@ -358,7 +358,7 @@ SUNLinSolSPTFQMR::SUNLinSolSPTFQMR(N_Vector x, int pretype, int maxl)

SUNLinSolSPTFQMR::SUNLinSolSPTFQMR(AmiVector const& x, int pretype, int maxl) {
solver_ = SUNLinSol_SPTFQMR(
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.sunctx_
const_cast<N_Vector>(x.getNVector()), pretype, maxl, x.get_ctx()
);
if (!solver_)
throw AmiException("Failed to create solver.");
Expand Down
8 changes: 3 additions & 5 deletions src/sundials_matrix_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ void SUNMatrixWrapper::reallocate(sunindextype NNZ) {

if (int ret = SUNSparseMatrix_Reallocate(matrix_, NNZ) != SUN_SUCCESS)
throw std::runtime_error(
"SUNSparseMatrix_Reallocate failed with "
"error code "
"SUNSparseMatrix_Reallocate failed with error code "
+ std::to_string(ret) + "."
);

Expand All @@ -165,8 +164,7 @@ void SUNMatrixWrapper::realloc() {
"CSR_MAT.");
if (int ret = SUNSparseMatrix_Realloc(matrix_) != SUN_SUCCESS)
throw std::runtime_error(
"SUNSparseMatrix_Realloc failed with "
"error code "
"SUNSparseMatrix_Realloc failed with error code "
+ std::to_string(ret) + "."
);

Expand Down Expand Up @@ -510,7 +508,7 @@ void SUNMatrixWrapper::sparse_sum(std::vector<SUNMatrixWrapper> const& mats) {

for (acol = 0; acol < columns(); acol++) {
set_indexptr(acol, nnz); /* column j of A starts here */
for (auto& mat : mats)
for (auto const& mat : mats)
nnz = mat.scatter(
acol, 1.0, w.data(), gsl::make_span(x), acol + 1, this, nnz
);
Expand Down
7 changes: 3 additions & 4 deletions src/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
namespace amici {

AmiVector& AmiVector::operator=(AmiVector const& other) {
sunctx_ = other.sunctx_;
vec_ = other.vec_;
synchroniseNVector();
synchroniseNVector(other.get_ctx());
return *this;
}

Expand Down Expand Up @@ -56,13 +55,13 @@ void AmiVector::copy(AmiVector const& other) {
std::copy(other.vec_.begin(), other.vec_.end(), vec_.begin());
}

void AmiVector::synchroniseNVector() {
void AmiVector::synchroniseNVector(SUNContext sunctx) {
if (nvec_)
N_VDestroy_Serial(nvec_);
nvec_ = vec_.empty()
? nullptr
: N_VMake_Serial(
gsl::narrow<long int>(vec_.size()), vec_.data(), sunctx_
gsl::narrow<long int>(vec_.size()), vec_.data(), sunctx
);
}

Expand Down

0 comments on commit b59e24f

Please sign in to comment.